@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -1,6 +1,7 @@
1
- #include "llama-kv-cache-recurrent.h"
1
+ #include "llama-memory-recurrent.h"
2
2
 
3
3
  #include "llama-impl.h"
4
+ #include "llama-io.h"
4
5
  #include "llama-batch.h"
5
6
  #include "llama-model.h"
6
7
 
@@ -11,27 +12,28 @@
11
12
  #include <stdexcept>
12
13
 
13
14
  //
14
- // llama_kv_cache_recurrent
15
+ // llama_memory_recurrent
15
16
  //
16
17
 
17
- llama_kv_cache_recurrent::llama_kv_cache_recurrent(
18
- const llama_model & model,
19
- ggml_type type_k,
20
- ggml_type type_v,
21
- bool offload,
22
- uint32_t kv_size,
23
- uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
18
+ llama_memory_recurrent::llama_memory_recurrent(
19
+ const llama_model & model,
20
+ layer_filter_cb && filter,
21
+ ggml_type type_r,
22
+ ggml_type type_s,
23
+ bool offload,
24
+ uint32_t mem_size,
25
+ uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
24
26
  const int32_t n_layer = hparams.n_layer;
25
27
 
26
- LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
27
- __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
28
+ LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
29
+ __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
28
30
 
29
31
  head = 0;
30
- size = kv_size;
32
+ size = mem_size;
31
33
  used = 0;
32
34
 
33
35
  cells.clear();
34
- cells.resize(kv_size);
36
+ cells.resize(mem_size);
35
37
 
36
38
  // create a context for each buffer type
37
39
  std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -58,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
58
60
  return it->second;
59
61
  };
60
62
 
61
- k_l.reserve(n_layer);
62
- v_l.reserve(n_layer);
63
+ r_l.resize(n_layer);
64
+ s_l.resize(n_layer);
63
65
 
64
66
  for (int i = 0; i < n_layer; i++) {
65
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
66
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
67
+ if (filter && !filter(i)) {
68
+ LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
69
+ continue;
70
+ }
67
71
 
68
72
  const char * dev_name = "CPU";
69
73
 
@@ -83,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
83
87
  throw std::runtime_error("failed to create ggml context for kv cache");
84
88
  }
85
89
 
86
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
87
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
88
- ggml_format_name(k, "cache_k_l%d", i);
89
- ggml_format_name(v, "cache_v_l%d", i);
90
- k_l.push_back(k);
91
- v_l.push_back(v);
90
+ ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
91
+ ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
92
+ ggml_format_name(r, "cache_r_l%d", i);
93
+ ggml_format_name(s, "cache_s_l%d", i);
94
+ r_l[i] = r;
95
+ s_l[i] = s;
92
96
  }
93
97
 
94
98
  // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -106,32 +110,35 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
106
110
  }
107
111
 
108
112
  {
109
- const size_t memory_size_k = size_k_bytes();
110
- const size_t memory_size_v = size_v_bytes();
113
+ const size_t memory_size_r = size_r_bytes();
114
+ const size_t memory_size_s = size_s_bytes();
111
115
 
112
- LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
113
- (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
114
- ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
115
- ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
116
+ LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
117
+ (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
118
+ ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
119
+ ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
116
120
  }
117
121
  }
118
122
 
119
- void llama_kv_cache_recurrent::clear() {
123
+ void llama_memory_recurrent::clear(bool data) {
120
124
  for (int32_t i = 0; i < (int32_t) size; ++i) {
121
125
  cells[i].pos = -1;
122
126
  cells[i].seq_id.clear();
123
127
  cells[i].src = -1;
124
128
  cells[i].tail = -1;
125
129
  }
130
+
126
131
  head = 0;
127
132
  used = 0;
128
133
 
129
- for (auto & buf : bufs) {
130
- ggml_backend_buffer_clear(buf.get(), 0);
134
+ if (data) {
135
+ for (auto & buf : bufs) {
136
+ ggml_backend_buffer_clear(buf.get(), 0);
137
+ }
131
138
  }
132
139
  }
133
140
 
134
- bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
141
+ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
135
142
  uint32_t new_head = size;
136
143
 
137
144
  if (p0 < 0) {
@@ -150,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
150
157
  if (0 <= seq_id) {
151
158
  int32_t & tail_id = cells[seq_id].tail;
152
159
  if (tail_id >= 0) {
153
- const kv_cell & cell = cells[tail_id];
160
+ const auto & cell = cells[tail_id];
154
161
  // partial intersection is invalid
155
162
  if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
156
163
  return false;
@@ -198,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
198
205
  return true;
199
206
  }
200
207
 
201
- void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
208
+ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
202
209
  if (seq_id_src == seq_id_dst) {
203
210
  return;
204
211
  }
@@ -212,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
212
219
  }
213
220
 
214
221
  if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
215
- kv_cell & tail_src = cells[seq_id_src];
216
- kv_cell & tail_dst = cells[seq_id_dst];
222
+ auto & tail_src = cells[seq_id_src];
223
+ auto & tail_dst = cells[seq_id_dst];
217
224
  if (tail_dst.tail >= 0) {
218
225
  // clear destination seq_id if it wasn't empty
219
- kv_cell & cell_dst = cells[tail_dst.tail];
226
+ auto & cell_dst = cells[tail_dst.tail];
220
227
 
221
228
  cell_dst.seq_id.erase(seq_id_dst);
222
229
  tail_dst.tail = -1;
@@ -227,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
227
234
  }
228
235
  }
229
236
  if (tail_src.tail >= 0) {
230
- kv_cell & cell_src = cells[tail_src.tail];
237
+ auto & cell_src = cells[tail_src.tail];
231
238
 
232
239
  cell_src.seq_id.insert(seq_id_dst);
233
240
  tail_dst.tail = tail_src.tail;
@@ -235,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
235
242
  }
236
243
  }
237
244
 
238
- void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
245
+ void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
239
246
  uint32_t new_head = size;
240
247
 
241
248
  for (uint32_t i = 0; i < size; ++i) {
@@ -267,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
267
274
  }
268
275
  }
269
276
 
270
- void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
277
+ void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
271
278
  if (shift == 0) {
272
279
  return;
273
280
  }
@@ -289,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
289
296
  if (0 <= seq_id && seq_id < (int64_t) size) {
290
297
  const int32_t tail_id = cells[seq_id].tail;
291
298
  if (tail_id >= 0) {
292
- kv_cell & cell = cells[tail_id];
299
+ auto & cell = cells[tail_id];
293
300
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
294
301
  cell.pos += shift;
295
302
  }
@@ -297,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
297
304
  }
298
305
  }
299
306
 
300
- void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
307
+ void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
301
308
  if (d == 1) {
302
309
  return;
303
310
  }
@@ -319,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
319
326
  if (0 <= seq_id && seq_id < (int64_t) size) {
320
327
  const int32_t tail_id = cells[seq_id].tail;
321
328
  if (tail_id >= 0) {
322
- kv_cell & cell = cells[tail_id];
329
+ auto & cell = cells[tail_id];
323
330
  if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
324
331
  cell.pos /= d;
325
332
  }
@@ -327,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
327
334
  }
328
335
  }
329
336
 
330
- llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
337
+ llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
331
338
  llama_pos result = std::numeric_limits<llama_pos>::max();
332
339
 
333
340
  for (uint32_t i = 0; i < size; ++i) {
@@ -343,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
343
350
  return result;
344
351
  }
345
352
 
346
- llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
353
+ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
347
354
  llama_pos result = -1;
348
355
 
349
356
  for (uint32_t i = 0; i < size; ++i) {
@@ -355,18 +362,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
355
362
  return result;
356
363
  }
357
364
 
358
- llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
359
- GGML_UNUSED(embd_pooled);
360
-
361
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
365
+ llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
366
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
362
367
 
363
368
  std::vector<llama_ubatch> ubatches;
364
369
 
365
370
  while (sbatch.n_tokens > 0) {
366
371
  llama_ubatch ubatch;
367
372
 
368
- if (embd_pooled) {
369
- // Pooled embeddings cannot be split across ubatches (yet)
373
+ if (embd_all) {
374
+ // if all tokens are output, split by sequence
370
375
  ubatch = sbatch.split_seq(n_ubatch);
371
376
  } else {
372
377
  ubatch = sbatch.split_equal(n_ubatch);
@@ -376,17 +381,24 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
376
381
  }
377
382
 
378
383
  if (!prepare(ubatches)) {
379
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
384
+ return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
380
385
  }
381
386
 
382
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
387
+ return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
383
388
  }
384
389
 
385
- llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
386
- return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
390
+ llama_memory_state_ptr llama_memory_recurrent::init_full() {
391
+ return std::make_unique<llama_memory_recurrent_state>(this);
387
392
  }
388
393
 
389
- bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
394
+ llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
395
+ GGML_UNUSED(lctx);
396
+ GGML_UNUSED(optimize);
397
+
398
+ return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
399
+ }
400
+
401
+ bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
390
402
  // simply remember the full state because it is very small for this type of cache
391
403
  // TODO: optimize
392
404
  auto org_cells = cells;
@@ -395,21 +407,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
395
407
 
396
408
  bool success = true;
397
409
 
398
- // TODO: here we have to verify that all ubatches can fit in the cells
399
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
400
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
401
- //
402
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
403
- //
404
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
405
- //
406
- GGML_UNUSED(ubatches);
407
- //for (const auto & ubatch : ubatches) {
408
- // if (!find_slot(ubatch)) {
409
- // success = false;
410
- // break;
411
- // }
412
- //}
410
+ for (const auto & ubatch : ubatches) {
411
+ if (!find_slot(ubatch)) {
412
+ success = false;
413
+ break;
414
+ }
415
+ }
413
416
 
414
417
  // restore the original state
415
418
  cells = std::move(org_cells);
@@ -419,26 +422,14 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
419
422
  return success;
420
423
  }
421
424
 
422
- bool llama_kv_cache_recurrent::update(llama_context & lctx) {
423
- GGML_UNUSED(lctx);
424
- // noop
425
- return false;
426
- }
427
-
428
- void llama_kv_cache_recurrent::defrag_sched(float thold) {
429
- GGML_UNUSED(thold);
430
- // noop
431
- }
432
-
433
- bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
434
- const uint32_t n_tokens = ubatch.n_tokens;
435
- const uint32_t n_seqs = ubatch.n_seqs;
425
+ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
426
+ const uint32_t n_seqs = ubatch.n_seqs;
436
427
 
437
428
  const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
438
429
 
439
430
  // if we have enough unused cells before the current head ->
440
431
  // better to start searching from the beginning of the cache, hoping to fill it
441
- if (head > used + 2*n_tokens) {
432
+ if (head > used + 2*n_seqs) {
442
433
  head = 0;
443
434
  }
444
435
 
@@ -465,9 +456,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
465
456
  return false;
466
457
  }
467
458
  if (j > 0) {
468
- kv_cell & seq = cells[seq_id];
459
+ auto & seq = cells[seq_id];
469
460
  if (seq.tail >= 0) {
470
- kv_cell & cell = cells[seq.tail];
461
+ auto & cell = cells[seq.tail];
471
462
  // clear cells from seq_ids that become shared
472
463
  // (should not normally happen, but let's handle it anyway)
473
464
  cell.seq_id.erase(seq_id);
@@ -487,7 +478,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
487
478
  std::vector<int32_t> tails_verif;
488
479
  tails_verif.assign(size, -1);
489
480
  for (uint32_t i = 0; i < size; ++i) {
490
- kv_cell & cell = cells[i];
481
+ auto & cell = cells[i];
491
482
  for (llama_seq_id seq_id : cell.seq_id) {
492
483
  if (tails_verif[seq_id] != -1) {
493
484
  LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -508,7 +499,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
508
499
 
509
500
  for (uint32_t i = 0; i < size; ++i) {
510
501
  if (next_empty_cell >= size) { next_empty_cell -= size; }
511
- kv_cell & cell = cells[next_empty_cell];
502
+ auto & cell = cells[next_empty_cell];
512
503
  if (cell.is_empty()) { break; }
513
504
  next_empty_cell += 1;
514
505
  }
@@ -516,34 +507,34 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
516
507
  // find usable cell range
517
508
  for (uint32_t s = 0; s < n_seqs; ++s) {
518
509
  const llama_seq_id seq_id = ubatch.seq_id[s][0];
519
- kv_cell & seq_meta = cells[seq_id];
510
+ auto & seq_meta = cells[seq_id];
520
511
  bool has_cell = false;
521
512
  if (seq_meta.tail >= 0) {
522
- kv_cell & cell = cells[seq_meta.tail];
513
+ auto & cell = cells[seq_meta.tail];
523
514
  GGML_ASSERT(cell.has_seq_id(seq_id));
524
515
  // does this seq_id "own" the cell?
525
516
  if (cell.seq_id.size() == 1) { has_cell = true; }
526
517
  }
527
518
  if (!has_cell) {
528
- kv_cell & empty_cell = cells[next_empty_cell];
519
+ auto & empty_cell = cells[next_empty_cell];
529
520
  GGML_ASSERT(empty_cell.is_empty());
530
521
  // copy old tail into the empty cell
531
522
  if (seq_meta.tail >= 0) {
532
- kv_cell & orig_cell = cells[seq_meta.tail];
523
+ auto & orig_cell = cells[seq_meta.tail];
533
524
  empty_cell.pos = orig_cell.pos;
534
525
  empty_cell.src = orig_cell.src;
535
526
  orig_cell.seq_id.erase(seq_id);
536
527
  empty_cell.seq_id.insert(seq_id); // will be overwritten
528
+ GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
537
529
  }
538
530
  seq_meta.tail = next_empty_cell;
539
531
  // find next empty cell
540
532
  if (s + 1 < n_seqs) {
541
- next_empty_cell += 1;
542
533
  for (uint32_t i = 0; i < size; ++i) {
534
+ next_empty_cell += 1;
543
535
  if (next_empty_cell >= size) { next_empty_cell -= size; }
544
- kv_cell & cell = cells[next_empty_cell];
536
+ auto & cell = cells[next_empty_cell];
545
537
  if (cell.is_empty()) { break; }
546
- next_empty_cell += 1;
547
538
  }
548
539
  }
549
540
  }
@@ -553,22 +544,24 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
553
544
 
554
545
  // gather and re-order
555
546
  for (uint32_t s = 0; s < n_seqs; ++s) {
556
- int32_t dst_id = s + min;
557
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
547
+ const int32_t dst_id = s + min;
548
+ const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
558
549
  if (dst_id != src_id) {
559
- kv_cell & dst_cell = cells[dst_id];
560
- kv_cell & src_cell = cells[src_id];
550
+ auto & dst_cell = cells[dst_id];
551
+ auto & src_cell = cells[src_id];
561
552
 
562
553
  std::swap(dst_cell.pos, src_cell.pos);
563
554
  std::swap(dst_cell.src, src_cell.src);
564
555
  std::swap(dst_cell.seq_id, src_cell.seq_id);
565
556
 
566
- // swap tails (assuming they NEVER overlap)
567
- for (const llama_seq_id seq_id : src_cell.seq_id) {
568
- cells[seq_id].tail = src_id;
569
- }
570
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
571
- cells[seq_id].tail = dst_id;
557
+ // swap tails
558
+ for (uint32_t i = 0; i < size; ++i) {
559
+ int32_t & tail = cells[i].tail;
560
+ if (tail == src_id) {
561
+ tail = dst_id;
562
+ } else if (tail == dst_id) {
563
+ tail = src_id;
564
+ }
572
565
  }
573
566
  }
574
567
  }
@@ -576,8 +569,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
576
569
  // update the pos of the used seqs
577
570
  for (uint32_t s = 0; s < n_seqs; ++s) {
578
571
  const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
579
- int32_t cell_id = s + min;
580
- kv_cell & cell = cells[cell_id];
572
+ const int32_t cell_id = s + min;
573
+ auto & cell = cells[cell_id];
581
574
 
582
575
  if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
583
576
  // What should happen when the pos backtracks or skips a value?
@@ -594,61 +587,54 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
594
587
  }
595
588
  }
596
589
 
590
+ // Find first cell without src refs, to use as the zero-ed state
591
+ {
592
+ // TODO: bake-in src refcounts in the cell metadata
593
+ std::vector<int32_t> refcounts(size, 0);
594
+ for (size_t i = 0; i < size; ++i) {
595
+ const int32_t src = cells[i].src;
596
+ if (src >= 0) {
597
+ refcounts[src] += 1;
598
+ }
599
+ }
600
+
601
+ rs_z = -1;
602
+ for (int i = min; i <= max; ++i) {
603
+ if (refcounts[i] == 0) {
604
+ rs_z = i;
605
+ break;
606
+ }
607
+ }
608
+
609
+ for (int i = min; i <= max; ++i) {
610
+ if (cells[i].src < 0) {
611
+ GGML_ASSERT(rs_z >= 0);
612
+ cells[i].src0 = rs_z;
613
+ } else {
614
+ // Stage the source ids for all used cells to allow correct seq_* behavior
615
+ // and still make these values available when setting the inputs
616
+ cells[i].src0 = cells[i].src;
617
+ }
618
+ cells[i].src = i; // avoid moving or clearing twice
619
+ }
620
+ }
621
+
597
622
  // allow getting the range of used cells, from head to head + n
598
623
  head = min;
599
624
  n = max - min + 1;
600
625
  used = std::count_if(cells.begin(), cells.end(),
601
- [](const kv_cell & cell){ return !cell.is_empty(); });
626
+ [](const mem_cell & cell){ return !cell.is_empty(); });
602
627
 
603
628
  // sanity check
604
629
  return n >= n_seqs;
605
630
  }
606
631
 
607
- bool llama_kv_cache_recurrent::get_can_shift() const {
608
- return false;
609
- }
610
-
611
- int32_t llama_kv_cache_recurrent::s_copy(int i) const {
612
- const uint32_t cell_id = i + head;
613
-
614
- //////////////////////////////////////////////
615
- // TODO: this should not mutate the KV cache !
616
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
617
-
618
- // prevent out-of-bound sources
619
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
620
- cell.src = cell_id;
621
- }
622
-
623
- int32_t res = cell.src;
624
-
625
- // TODO: do not mutate the KV cache
626
- // ensure copy only happens once
627
- if (cell.src != (int32_t) cell_id) {
628
- cell.src = cell_id;
629
- }
630
-
631
- return res;
632
- }
633
-
634
- float llama_kv_cache_recurrent::s_mask(int i) const {
635
- const uint32_t cell_id = i + head;
636
-
637
- //////////////////////////////////////////////
638
- // TODO: this should not mutate the KV cache !
639
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
640
-
641
- float res = (float) (cell.src >= 0);
642
-
643
- // only clear once
644
- if (cell.src < 0) {
645
- cell.src = cell_id;
646
- }
647
-
648
- return res;
632
+ bool llama_memory_recurrent::get_can_shift() const {
633
+ // shifting the pos is trivial for recurrent models
634
+ return true;
649
635
  }
650
636
 
651
- size_t llama_kv_cache_recurrent::total_size() const {
637
+ size_t llama_memory_recurrent::total_size() const {
652
638
  size_t size = 0;
653
639
  for (const auto & buf : bufs) {
654
640
  size += ggml_backend_buffer_get_size(buf.get());
@@ -657,27 +643,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
657
643
  return size;
658
644
  }
659
645
 
660
- size_t llama_kv_cache_recurrent::size_k_bytes() const {
661
- size_t size_k_bytes = 0;
646
+ size_t llama_memory_recurrent::size_r_bytes() const {
647
+ size_t size_r_bytes = 0;
662
648
 
663
- for (const auto & k : k_l) {
664
- size_k_bytes += ggml_nbytes(k);
649
+ for (const auto & r : r_l) {
650
+ if (r != nullptr) {
651
+ size_r_bytes += ggml_nbytes(r);
652
+ }
665
653
  }
666
654
 
667
- return size_k_bytes;
655
+ return size_r_bytes;
668
656
  }
669
657
 
670
- size_t llama_kv_cache_recurrent::size_v_bytes() const {
671
- size_t size_v_bytes = 0;
658
+ size_t llama_memory_recurrent::size_s_bytes() const {
659
+ size_t size_s_bytes = 0;
672
660
 
673
- for (const auto & v : v_l) {
674
- size_v_bytes += ggml_nbytes(v);
661
+ for (const auto & s : s_l) {
662
+ if (s != nullptr) {
663
+ size_s_bytes += ggml_nbytes(s);
664
+ }
675
665
  }
676
666
 
677
- return size_v_bytes;
667
+ return size_s_bytes;
678
668
  }
679
669
 
680
- void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
670
+ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
681
671
  std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
682
672
  uint32_t cell_count = 0;
683
673
 
@@ -715,7 +705,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
715
705
  state_write_data(io, cell_ranges);
716
706
  }
717
707
 
718
- void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
708
+ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
719
709
  uint32_t cell_count;
720
710
  io.read_to(&cell_count, sizeof(cell_count));
721
711
 
@@ -726,7 +716,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
726
716
 
727
717
  if (!res) {
728
718
  if (seq_id == -1) {
729
- clear();
719
+ clear(true);
730
720
  } else {
731
721
  seq_rm(seq_id, -1, -1);
732
722
  }
@@ -734,7 +724,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
734
724
  }
735
725
  }
736
726
 
737
- void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
727
+ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
738
728
  for (const auto & range : cell_ranges) {
739
729
  for (uint32_t i = range.first; i < range.second; ++i) {
740
730
  const auto & cell = cells[i];
@@ -753,87 +743,85 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
753
743
  }
754
744
  }
755
745
 
756
- void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
757
- const uint32_t v_trans = 0;
746
+ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
747
+ const uint32_t s_trans = 0;
758
748
  const uint32_t n_layer = hparams.n_layer;
759
749
 
760
- io.write(&v_trans, sizeof(v_trans));
761
- io.write(&n_layer, sizeof(n_layer));
750
+ io.write(&s_trans, sizeof(s_trans));
751
+ io.write(&n_layer, sizeof(n_layer));
762
752
 
763
753
  std::vector<uint8_t> tmp_buf;
764
754
 
765
755
  // Iterate and write all the keys first, each row is a cell
766
756
  // Get whole range at a time
767
757
  for (uint32_t il = 0; il < n_layer; ++il) {
768
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
769
758
 
770
759
  // Write key type
771
- const int32_t k_type_i = (int32_t)k_l[il]->type;
772
- io.write(&k_type_i, sizeof(k_type_i));
760
+ const int32_t r_type_i = (int32_t)r_l[il]->type;
761
+ io.write(&r_type_i, sizeof(r_type_i));
773
762
 
774
763
  // Write row size of key
775
- const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
776
- io.write(&k_size_row, sizeof(k_size_row));
764
+ const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
765
+ io.write(&r_size_row, sizeof(r_size_row));
777
766
 
778
767
  // Read each range of cells of k_size length each into tmp_buf and write out
779
768
  for (const auto & range : cell_ranges) {
780
769
  const size_t range_size = range.second - range.first;
781
- const size_t buf_size = range_size * k_size_row;
782
- io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
770
+ const size_t buf_size = range_size * r_size_row;
771
+ io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
783
772
  }
784
773
  }
785
774
 
786
- if (!v_trans) {
775
+ if (!s_trans) {
787
776
  for (uint32_t il = 0; il < n_layer; ++il) {
788
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
789
777
 
790
778
  // Write value type
791
- const int32_t v_type_i = (int32_t)v_l[il]->type;
792
- io.write(&v_type_i, sizeof(v_type_i));
779
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
780
+ io.write(&s_type_i, sizeof(s_type_i));
793
781
 
794
782
  // Write row size of value
795
- const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
796
- io.write(&v_size_row, sizeof(v_size_row));
783
+ const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
784
+ io.write(&s_size_row, sizeof(s_size_row));
797
785
 
798
- // Read each range of cells of v_size length each into tmp_buf and write out
786
+ // Read each range of cells of s_size length each into tmp_buf and write out
799
787
  for (const auto & range : cell_ranges) {
800
788
  const size_t range_size = range.second - range.first;
801
- const size_t buf_size = range_size * v_size_row;
802
- io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
789
+ const size_t buf_size = range_size * s_size_row;
790
+ io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
803
791
  }
804
792
  }
805
793
  } else {
806
794
  // When v is transposed, we also need the element size and get the element ranges from each row
807
- const uint32_t kv_size = size;
795
+ const uint32_t mem_size = size;
808
796
  for (uint32_t il = 0; il < n_layer; ++il) {
809
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
797
+ const uint32_t n_embd_s = hparams.n_embd_s();
810
798
 
811
799
  // Write value type
812
- const int32_t v_type_i = (int32_t)v_l[il]->type;
813
- io.write(&v_type_i, sizeof(v_type_i));
800
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
801
+ io.write(&s_type_i, sizeof(s_type_i));
814
802
 
815
803
  // Write element size
816
- const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
817
- io.write(&v_size_el, sizeof(v_size_el));
804
+ const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
805
+ io.write(&s_size_el, sizeof(s_size_el));
818
806
 
819
807
  // Write GQA embedding size
820
- io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
808
+ io.write(&n_embd_s, sizeof(n_embd_s));
821
809
 
822
810
  // For each row, we get the element values of each cell
823
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
811
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
824
812
  // Read each range of cells of v_size_el length each into tmp_buf and write out
825
813
  for (const auto & range : cell_ranges) {
826
814
  const size_t range_size = range.second - range.first;
827
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
828
- const size_t buf_size = range_size * v_size_el;
829
- io.write_tensor(v_l[il], src_offset, buf_size);
815
+ const size_t src_offset = (range.first + j * mem_size) * s_size_el;
816
+ const size_t buf_size = range_size * s_size_el;
817
+ io.write_tensor(s_l[il], src_offset, buf_size);
830
818
  }
831
819
  }
832
820
  }
833
821
  }
834
822
  }
835
823
 
836
- bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
824
+ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
837
825
  if (dest_seq_id != -1) {
838
826
  // single sequence
839
827
 
@@ -883,10 +871,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
883
871
  return false;
884
872
  }
885
873
 
886
- clear();
874
+ clear(true);
887
875
 
888
876
  for (uint32_t i = 0; i < cell_count; ++i) {
889
- kv_cell & cell = cells[i];
877
+ auto & cell = cells[i];
890
878
 
891
879
  llama_pos pos;
892
880
  uint32_t n_seq_id;
@@ -900,7 +888,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
900
888
  llama_seq_id seq_id;
901
889
  io.read_to(&seq_id, sizeof(seq_id));
902
890
 
903
- // TODO: llama_kv_cache_recurrent should have a notion of max sequences
891
+ // TODO: llama_memory_recurrent should have a notion of max sequences
904
892
  //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
905
893
  if (seq_id < 0) {
906
894
  //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -932,10 +920,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
932
920
  return true;
933
921
  }
934
922
 
935
- bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
936
- uint32_t v_trans;
923
+ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
924
+ uint32_t s_trans;
937
925
  uint32_t n_layer;
938
- io.read_to(&v_trans, sizeof(v_trans));
926
+ io.read_to(&s_trans, sizeof(s_trans));
939
927
  io.read_to(&n_layer, sizeof(n_layer));
940
928
 
941
929
  if (n_layer != hparams.n_layer) {
@@ -946,102 +934,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
946
934
  LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
947
935
  return false;
948
936
  }
949
- if (false != (bool) v_trans) {
950
- LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
937
+ if (false != (bool) s_trans) {
938
+ LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
951
939
  return false;
952
940
  }
953
941
 
954
942
  // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
955
943
  for (uint32_t il = 0; il < n_layer; ++il) {
956
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
957
944
 
958
945
  // Read type of key
959
- int32_t k_type_i_ref;
960
- io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
961
- const int32_t k_type_i = (int32_t) k_l[il]->type;
962
- if (k_type_i != k_type_i_ref) {
963
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
946
+ int32_t r_type_i_ref;
947
+ io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
948
+ const int32_t r_type_i = (int32_t) r_l[il]->type;
949
+ if (r_type_i != r_type_i_ref) {
950
+ LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
964
951
  return false;
965
952
  }
966
953
 
967
954
  // Read row size of key
968
- uint64_t k_size_row_ref;
969
- io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
970
- const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
971
- if (k_size_row != k_size_row_ref) {
972
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
955
+ uint64_t r_size_row_ref;
956
+ io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
957
+ const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
958
+ if (r_size_row != r_size_row_ref) {
959
+ LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
973
960
  return false;
974
961
  }
975
962
 
976
963
  if (cell_count) {
977
964
  // Read and set the keys for the whole cell range
978
- ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
965
+ ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
979
966
  }
980
967
  }
981
968
 
982
- if (!v_trans) {
969
+ if (!s_trans) {
983
970
  for (uint32_t il = 0; il < n_layer; ++il) {
984
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
985
971
 
986
972
  // Read type of value
987
- int32_t v_type_i_ref;
988
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
989
- const int32_t v_type_i = (int32_t)v_l[il]->type;
990
- if (v_type_i != v_type_i_ref) {
991
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
973
+ int32_t s_type_i_ref;
974
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
975
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
976
+ if (s_type_i != s_type_i_ref) {
977
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
992
978
  return false;
993
979
  }
994
980
 
995
981
  // Read row size of value
996
- uint64_t v_size_row_ref;
997
- io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
998
- const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
999
- if (v_size_row != v_size_row_ref) {
1000
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
982
+ uint64_t s_size_row_ref;
983
+ io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
984
+ const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
985
+ if (s_size_row != s_size_row_ref) {
986
+ LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
1001
987
  return false;
1002
988
  }
1003
989
 
1004
990
  if (cell_count) {
1005
991
  // Read and set the values for the whole cell range
1006
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
992
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
1007
993
  }
1008
994
  }
1009
995
  } else {
1010
996
  // For each layer, read the values for each cell (transposed)
1011
997
  for (uint32_t il = 0; il < n_layer; ++il) {
1012
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
998
+ const uint32_t n_embd_s = hparams.n_embd_s();
1013
999
 
1014
1000
  // Read type of value
1015
- int32_t v_type_i_ref;
1016
- io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
1017
- const int32_t v_type_i = (int32_t)v_l[il]->type;
1018
- if (v_type_i != v_type_i_ref) {
1019
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
1001
+ int32_t s_type_i_ref;
1002
+ io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
1003
+ const int32_t s_type_i = (int32_t)s_l[il]->type;
1004
+ if (s_type_i != s_type_i_ref) {
1005
+ LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
1020
1006
  return false;
1021
1007
  }
1022
1008
 
1023
1009
  // Read element size of value
1024
- uint32_t v_size_el_ref;
1025
- io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
1026
- const size_t v_size_el = ggml_type_size(v_l[il]->type);
1027
- if (v_size_el != v_size_el_ref) {
1028
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
1010
+ uint32_t s_size_el_ref;
1011
+ io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
1012
+ const size_t s_size_el = ggml_type_size(s_l[il]->type);
1013
+ if (s_size_el != s_size_el_ref) {
1014
+ LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
1029
1015
  return false;
1030
1016
  }
1031
1017
 
1032
- // Read GQA embedding size
1033
- uint32_t n_embd_v_gqa_ref;
1034
- io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
1035
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
1036
- LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
1018
+ // Read state embedding size
1019
+ uint32_t n_embd_s_ref;
1020
+ io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
1021
+ if (n_embd_s != n_embd_s_ref) {
1022
+ LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
1037
1023
  return false;
1038
1024
  }
1039
1025
 
1040
1026
  if (cell_count) {
1041
1027
  // For each row in the transposed matrix, read the values for the whole cell range
1042
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
1043
- const size_t dst_offset = (head + j * size) * v_size_el;
1044
- ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
1028
+ for (uint32_t j = 0; j < n_embd_s; ++j) {
1029
+ const size_t dst_offset = (head + j * size) * s_size_el;
1030
+ ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
1045
1031
  }
1046
1032
  }
1047
1033
  }
@@ -1051,25 +1037,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
1051
1037
  }
1052
1038
 
1053
1039
  //
1054
- // llama_kv_cache_recurrent_state
1040
+ // llama_memory_recurrent_state
1055
1041
  //
1056
1042
 
1057
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
1043
+ llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
1058
1044
 
1059
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1060
- llama_memory_status status,
1061
- llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
1045
+ llama_memory_recurrent_state::llama_memory_recurrent_state(
1046
+ llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
1062
1047
  }
1063
1048
 
1064
- llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
1065
- llama_memory_status status,
1066
- llama_kv_cache_recurrent * kv,
1049
+ llama_memory_recurrent_state::llama_memory_recurrent_state(
1050
+ llama_memory_recurrent * mem,
1067
1051
  llama_sbatch sbatch,
1068
- std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1052
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
1069
1053
 
1070
- llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
1054
+ llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
1071
1055
 
1072
- bool llama_kv_cache_recurrent_state::next() {
1056
+ bool llama_memory_recurrent_state::next() {
1073
1057
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1074
1058
 
1075
1059
  if (++i_next >= ubatches.size()) {
@@ -1079,54 +1063,54 @@ bool llama_kv_cache_recurrent_state::next() {
1079
1063
  return true;
1080
1064
  }
1081
1065
 
1082
- bool llama_kv_cache_recurrent_state::apply() {
1066
+ bool llama_memory_recurrent_state::apply() {
1083
1067
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1084
1068
 
1085
- kv->find_slot(ubatches[i_next]);
1069
+ mem->find_slot(ubatches[i_next]);
1086
1070
 
1087
1071
  return true;
1088
1072
  }
1089
1073
 
1090
- std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
1074
+ std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
1091
1075
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1092
1076
 
1093
1077
  return sbatch.out_ids;
1094
1078
  }
1095
1079
 
1096
- llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
1080
+ llama_memory_status llama_memory_recurrent_state::get_status() const {
1097
1081
  return status;
1098
1082
  }
1099
1083
 
1100
- const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
1084
+ const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
1101
1085
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1102
1086
 
1103
1087
  return ubatches[i_next];
1104
1088
  }
1105
1089
 
1106
- uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
1107
- return is_full ? kv->size : kv->n;
1090
+ uint32_t llama_memory_recurrent_state::get_n_rs() const {
1091
+ return is_full ? mem->size : mem->n;
1108
1092
  }
1109
1093
 
1110
- uint32_t llama_kv_cache_recurrent_state::get_head() const {
1111
- return is_full ? 0 : kv->head;
1094
+ uint32_t llama_memory_recurrent_state::get_head() const {
1095
+ return is_full ? 0 : mem->head;
1112
1096
  }
1113
1097
 
1114
- uint32_t llama_kv_cache_recurrent_state::get_size() const {
1115
- return kv->size;
1098
+ int32_t llama_memory_recurrent_state::get_rs_z() const {
1099
+ return is_full ? 0 : mem->rs_z;
1116
1100
  }
1117
1101
 
1118
- ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
1119
- return kv->k_l[il];
1102
+ uint32_t llama_memory_recurrent_state::get_size() const {
1103
+ return mem->size;
1120
1104
  }
1121
1105
 
1122
- ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
1123
- return kv->v_l[il];
1106
+ ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
1107
+ return mem->r_l[il];
1124
1108
  }
1125
1109
 
1126
- int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
1127
- return kv->s_copy(i);
1110
+ ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
1111
+ return mem->s_l[il];
1128
1112
  }
1129
1113
 
1130
- float llama_kv_cache_recurrent_state::s_mask(int i) const {
1131
- return kv->s_mask(i);
1114
+ int32_t llama_memory_recurrent_state::s_copy(int i) const {
1115
+ return mem->cells[i + mem->head].src0;
1132
1116
  }