@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192) hide show
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -1,7 +1,6 @@
1
1
  #pragma once
2
2
 
3
3
  #include "llama.h"
4
- #include "llama-batch.h"
5
4
  #include "llama-cparams.h"
6
5
  #include "llama-graph.h"
7
6
  #include "llama-adapter.h"
@@ -13,13 +12,13 @@
13
12
  #include <vector>
14
13
 
15
14
  struct llama_model;
16
- struct llama_kv_cache;
15
+ class llama_batch_allocr;
17
16
 
18
17
  class llama_io_read_i;
19
18
  class llama_io_write_i;
20
19
 
21
- class llama_memory_i;
22
- class llama_memory_state_i;
20
+ struct llama_memory_i;
21
+ struct llama_memory_state_i;
23
22
 
24
23
  struct llama_context {
25
24
  // init scheduler and compute buffers, reserve worst-case graphs
@@ -47,12 +46,12 @@ struct llama_context {
47
46
  uint32_t n_threads() const;
48
47
  uint32_t n_threads_batch() const;
49
48
 
50
- llama_kv_cache * get_kv_self();
51
- const llama_kv_cache * get_kv_self() const;
49
+ llama_memory_t get_memory() const;
52
50
 
53
51
  // return true of the KV cache was updated
54
52
  // TODO: remove
55
- bool kv_self_update();
53
+ bool kv_self_update(bool optimize);
54
+ void kv_self_defrag_sched();
56
55
 
57
56
  enum llama_pooling_type pooling_type() const;
58
57
 
@@ -103,8 +102,8 @@ struct llama_context {
103
102
  llama_memory_state_i * mstate,
104
103
  ggml_status & ret);
105
104
 
106
- int encode(llama_batch & inp_batch);
107
- int decode(llama_batch & inp_batch);
105
+ int encode(const llama_batch & batch_inp);
106
+ int decode(const llama_batch & batch_inp);
108
107
 
109
108
  //
110
109
  // state save/load
@@ -182,7 +181,7 @@ private:
182
181
 
183
182
  // Make sure enough space is available for outputs.
184
183
  // Returns max number of outputs for which space was reserved.
185
- int32_t output_reserve(int32_t n_outputs);
184
+ uint32_t output_reserve(int32_t n_outputs);
186
185
 
187
186
  //
188
187
  // graph
@@ -231,6 +230,9 @@ private:
231
230
 
232
231
  std::unique_ptr<llama_memory_i> memory;
233
232
 
233
+ // TODO: temporary, until the llama_kv_self_defrag() API is removed
234
+ bool memory_force_optimize = false;
235
+
234
236
  // decode output (2-dimensional array: [n_outputs][n_vocab])
235
237
  size_t logits_size = 0; // capacity (of floats) for logits
236
238
  float * logits = nullptr;
@@ -244,8 +246,10 @@ private:
244
246
  // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
245
247
  std::map<llama_seq_id, std::vector<float>> embd_seq;
246
248
 
247
- int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
248
- int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
249
+ // reuse the batch_allocr to avoid unnecessary memory allocations
250
+ std::unique_ptr<llama_batch_allocr> batch_allocr;
251
+
252
+ uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
249
253
 
250
254
  std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
251
255
 
@@ -1,5 +1,5 @@
1
1
  #include "llama-cparams.h"
2
2
 
3
3
  size_t llama_max_parallel_sequences(void) {
4
- return LLAMA_MAX_PARALLEL_SEQUENCES;
4
+ return LLAMA_MAX_SEQ;
5
5
  }
@@ -4,7 +4,7 @@
4
4
 
5
5
  #include <cstdint>
6
6
 
7
- #define LLAMA_MAX_PARALLEL_SEQUENCES 64
7
+ #define LLAMA_MAX_SEQ 64
8
8
 
9
9
  struct llama_cparams {
10
10
  uint32_t n_ctx; // context size used during inference
@@ -6,7 +6,8 @@
6
6
 
7
7
  #include "llama-kv-cache-unified.h"
8
8
  #include "llama-kv-cache-unified-iswa.h"
9
- #include "llama-kv-cache-recurrent.h"
9
+ #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-recurrent.h"
10
11
 
11
12
  #include <cassert>
12
13
  #include <cmath>
@@ -139,6 +140,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
139
140
 
140
141
  std::vector<uint64_t> sum(n_tokens, 0);
141
142
 
143
+ // TODO: fix indexing [UBATCH_IDX]
142
144
  for (int s = 0; s < n_seqs; ++s) {
143
145
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
144
146
 
@@ -156,6 +158,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
156
158
  }
157
159
  }
158
160
 
161
+ // TODO: fix indexing [UBATCH_IDX]
159
162
  for (int s = 0; s < n_seqs; ++s) {
160
163
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
161
164
 
@@ -180,6 +183,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
180
183
  uint32_t * data = (uint32_t *) cls->data;
181
184
  memset(cls->data, 0, n_tokens * ggml_element_size(cls));
182
185
 
186
+ // TODO: fix indexing [UBATCH_IDX]
183
187
  for (int s = 0; s < n_seqs; ++s) {
184
188
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
185
189
 
@@ -210,6 +214,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
210
214
  std::vector<int> last_pos(n_tokens, -1);
211
215
  std::vector<int> last_row(n_tokens, -1);
212
216
 
217
+ // TODO: fix indexing [UBATCH_IDX]
213
218
  for (int s = 0; s < n_seqs; ++s) {
214
219
  const llama_seq_id seq_id = ubatch->seq_id[s][0];
215
220
 
@@ -234,34 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
234
239
  }
235
240
  }
236
241
 
237
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
238
243
  GGML_UNUSED(ubatch);
239
244
 
240
- const int64_t n_kv = kv_state->get_n_kv();
245
+ const int64_t n_rs = mem_state->get_n_rs();
241
246
 
242
247
  if (s_copy) {
243
248
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
244
249
  int32_t * data = (int32_t *) s_copy->data;
245
250
 
246
251
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
247
- for (uint32_t i = 0; i < n_kv; ++i) {
248
- data[i] = kv_state->s_copy(i);
249
- }
250
- }
251
- }
252
-
253
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
254
- GGML_UNUSED(ubatch);
255
-
256
- const int64_t n_kv = kv_state->get_n_kv();
257
-
258
- if (s_mask) {
259
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
260
- float * data = (float *) s_mask->data;
261
-
262
- // clear unused states
263
- for (int i = 0; i < n_kv; ++i) {
264
- data[i] = kv_state->s_mask(i);
252
+ for (uint32_t i = 0; i < n_rs; ++i) {
253
+ data[i] = mem_state->s_copy(i);
265
254
  }
266
255
  }
267
256
  }
@@ -299,6 +288,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
299
288
  const int32_t ti = s0*n_seq_tokens + i;
300
289
  float f = -INFINITY;
301
290
 
291
+ // TODO: fix indexing [UBATCH_IDX]
302
292
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
303
293
  if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
304
294
  if (hparams.use_alibi) {
@@ -338,6 +328,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
338
328
  const int32_t ti = s0*n_seq_tokens + i;
339
329
  float f = -INFINITY;
340
330
 
331
+ // TODO: fix indexing [UBATCH_IDX]
341
332
  for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
342
333
  if (ubatch->seq_id[s0][s] == seq_id) {
343
334
  if (hparams.use_alibi) {
@@ -393,6 +384,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
393
384
  for (int j = 0; j < n_tokens; ++j) {
394
385
  for (int i = 0; i < n_enc; ++i) {
395
386
  float f = -INFINITY;
387
+ // TODO: fix indexing [UBATCH_IDX]
396
388
  for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
397
389
  const llama_seq_id seq_id = ubatch->seq_id[j][s];
398
390
  if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
@@ -412,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
412
404
  }
413
405
  }
414
406
 
407
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
408
+ if (self_kq_mask) {
409
+ mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
410
+ }
411
+
412
+ const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
413
+
414
+ if (s_copy) {
415
+ GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
416
+ int32_t * data = (int32_t *) s_copy->data;
417
+
418
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
419
+ for (uint32_t i = 0; i < n_rs; ++i) {
420
+ data[i] = mem_state->get_state_recr()->s_copy(i);
421
+ }
422
+ }
423
+ }
424
+
415
425
  //
416
426
  // llm_graph_context
417
427
  //
@@ -650,6 +660,7 @@ ggml_tensor * llm_graph_context::build_ffn(
650
660
  {
651
661
  // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
652
662
  int64_t split_point = cur->ne[0] / 2;
663
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
653
664
  ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
654
665
  ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
655
666
 
@@ -659,6 +670,20 @@ ggml_tensor * llm_graph_context::build_ffn(
659
670
  cur = ggml_mul(ctx0, x0, x1);
660
671
  cb(cur, "ffn_mul", il);
661
672
  } break;
673
+ case LLM_FFN_GEGLU:
674
+ {
675
+ // Split into two equal parts
676
+ int64_t split_point = cur->ne[0] / 2;
677
+ // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
678
+ ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
679
+ ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
680
+
681
+ x0 = ggml_gelu(ctx0, x0);
682
+ cb(x0, "ffn_gelu", il);
683
+
684
+ cur = ggml_mul(ctx0, x0, x1);
685
+ cb(cur, "ffn_geglu", il);
686
+ } break;
662
687
  }
663
688
 
664
689
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -769,9 +794,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
769
794
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
770
795
 
771
796
  if (weight_before_ffn) {
772
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
773
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
774
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
797
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
798
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
775
799
  cur = ggml_mul(ctx0, repeated, weights);
776
800
  cb(cur, "ffn_moe_weighted", il);
777
801
  }
@@ -956,40 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
956
980
  return cur;
957
981
  }
958
982
 
959
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
960
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
961
-
962
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
963
-
964
- const auto n_kv = kv_state->get_n_kv();
965
-
966
- auto & cur = inp->s_copy;
967
-
968
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
969
- ggml_set_input(cur);
970
-
971
- res->add_input(std::move(inp));
972
-
973
- return cur;
974
- }
975
-
976
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
977
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
978
-
979
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
980
-
981
- const auto n_kv = kv_state->get_n_kv();
982
-
983
- auto & cur = inp->s_mask;
984
-
985
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
986
- ggml_set_input(cur);
987
-
988
- res->add_input(std::move(inp));
989
-
990
- return cur;
991
- }
992
-
993
983
  ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
994
984
  auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
995
985
 
@@ -1059,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1059
1049
  return pos_bias;
1060
1050
  }
1061
1051
 
1052
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1053
+ const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
1054
+
1055
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
1056
+
1057
+ {
1058
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
1059
+
1060
+ const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
1061
+
1062
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1063
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1064
+ ggml_set_input(inp->self_kq_mask);
1065
+
1066
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1067
+ }
1068
+
1069
+ {
1070
+ const auto n_rs = mem_state->get_state_recr()->get_n_rs();
1071
+
1072
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1073
+ ggml_set_input(inp->s_copy);
1074
+ }
1075
+
1076
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1077
+ }
1078
+
1062
1079
  ggml_tensor * llm_graph_context::build_attn_mha(
1063
1080
  ggml_cgraph * gf,
1064
1081
  ggml_tensor * q,
@@ -1303,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
1303
1320
  return cur;
1304
1321
  }
1305
1322
 
1306
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1307
- const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1308
-
1309
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1310
-
1311
- {
1312
- const auto n_kv = kv_state->get_base()->get_n_kv();
1313
-
1314
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1315
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1316
- ggml_set_input(inp->self_kq_mask);
1317
-
1318
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1319
- }
1320
-
1321
- {
1322
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1323
-
1324
- const auto n_kv = kv_state->get_swa()->get_n_kv();
1325
-
1326
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1327
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1328
- ggml_set_input(inp->self_kq_mask_swa);
1329
-
1330
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1331
- }
1332
-
1333
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1334
- }
1335
-
1336
1323
  ggml_tensor * llm_graph_context::build_attn(
1337
1324
  llm_graph_input_attn_kv_unified_iswa * inp,
1338
1325
  ggml_cgraph * gf,
@@ -1442,56 +1429,182 @@ ggml_tensor * llm_graph_context::build_attn(
1442
1429
  return cur;
1443
1430
  }
1444
1431
 
1445
- ggml_tensor * llm_graph_context::build_copy_mask_state(
1446
- ggml_cgraph * gf,
1447
- ggml_tensor * s,
1448
- ggml_tensor * state_copy,
1449
- ggml_tensor * state_mask,
1450
- int32_t n_state,
1451
- int32_t n_seqs) const {
1452
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1453
-
1454
- const auto n_kv = kv_state->get_n_kv();
1455
- const auto kv_head = kv_state->get_head();
1432
+ ggml_tensor * llm_graph_context::build_attn(
1433
+ llm_graph_input_mem_hybrid * inp,
1434
+ ggml_cgraph * gf,
1435
+ ggml_tensor * wo,
1436
+ ggml_tensor * wo_b,
1437
+ ggml_tensor * q_cur,
1438
+ ggml_tensor * k_cur,
1439
+ ggml_tensor * v_cur,
1440
+ ggml_tensor * kq_b,
1441
+ ggml_tensor * v_mla,
1442
+ float kq_scale,
1443
+ int il) const {
1444
+ // these nodes are added to the graph together so that they are not reordered
1445
+ // by doing so, the number of splits in the graph is reduced
1446
+ ggml_build_forward_expand(gf, q_cur);
1447
+ ggml_build_forward_expand(gf, k_cur);
1448
+ ggml_build_forward_expand(gf, v_cur);
1449
+
1450
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
1451
+
1452
+ // store to KV cache
1453
+ {
1454
+ ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1455
+ ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1456
+ }
1457
+
1458
+ const auto & kq_mask = inp->get_kq_mask();
1459
+
1460
+ ggml_tensor * q = q_cur;
1461
+ ggml_tensor * k = kv_state->get_k(ctx0, il);
1462
+ ggml_tensor * v = kv_state->get_v(ctx0, il);
1463
+
1464
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1465
+ cb(cur, "kqv_out", il);
1466
+
1467
+ if (wo) {
1468
+ cur = build_lora_mm(wo, cur);
1469
+ if (arch == LLM_ARCH_GLM4) {
1470
+ // GLM4 seems to have numerical issues with half-precision accumulators
1471
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1472
+ }
1473
+ }
1474
+
1475
+ if (wo_b) {
1476
+ cur = ggml_add(ctx0, cur, wo_b);
1477
+ }
1478
+
1479
+ return cur;
1480
+ }
1456
1481
 
1457
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
1482
+ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1483
+ const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
1484
+
1485
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
1486
+
1487
+ {
1488
+ const auto n_kv = kv_state->get_base()->get_n_kv();
1489
+
1490
+ inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1491
+ //cb(inp->self_kq_mask, "KQ_mask", -1);
1492
+ ggml_set_input(inp->self_kq_mask);
1493
+
1494
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1495
+ }
1496
+
1497
+ {
1498
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1499
+
1500
+ const auto n_kv = kv_state->get_swa()->get_n_kv();
1501
+
1502
+ inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1503
+ //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1504
+ ggml_set_input(inp->self_kq_mask_swa);
1505
+
1506
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1507
+ }
1458
1508
 
1459
- // copy states
1460
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1461
- // this shrinks the tensors's ne[1] to n_kv
1462
- states = ggml_get_rows(ctx0, states, state_copy);
1509
+ return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1510
+ }
1463
1511
 
1464
- // clear states of sequences which are starting at the beginning of this batch
1465
- // FIXME: zero-out NANs?
1466
- states = ggml_mul(ctx0, states, state_mask);
1512
+ ggml_tensor * llm_graph_context::build_rs(
1513
+ ggml_cgraph * gf,
1514
+ ggml_tensor * s,
1515
+ ggml_tensor * state_copy,
1516
+ int32_t state_size,
1517
+ int32_t n_seqs,
1518
+ uint32_t n_kv,
1519
+ uint32_t kv_head,
1520
+ uint32_t kv_size,
1521
+ int32_t rs_zero,
1522
+ bool avoid_copies) const {
1523
+
1524
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1525
+
1526
+ // Clear a single state which will then be copied to the other cleared states.
1527
+ // Note that this is a no-op when the view is zero-sized.
1528
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1529
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1530
+
1531
+ ggml_tensor * output_states;
1532
+
1533
+ if (!avoid_copies) {
1534
+ // copy states
1535
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1536
+ // {state_size, kv_size} -> {state_size, n_seqs}
1537
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1538
+ ggml_build_forward_expand(gf, output_states);
1539
+ } else {
1540
+ // FIXME: make the gathering operation happen before the copy below
1541
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1542
+ output_states = states;
1543
+ }
1467
1544
 
1468
- // copy states which won't be changed further (between n_seqs and n_kv)
1545
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
1546
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1469
1547
  ggml_build_forward_expand(gf,
1470
1548
  ggml_cpy(ctx0,
1471
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
1472
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
1549
+ states_extra,
1550
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1473
1551
 
1474
- // the part of the states that will be used and modified
1475
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
1552
+ return output_states;
1553
+ }
1554
+
1555
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1556
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1557
+
1558
+ auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
1559
+
1560
+ const auto n_rs = kv_state->get_n_rs();
1561
+
1562
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1563
+ ggml_set_input(inp->s_copy);
1564
+
1565
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1566
+ }
1567
+
1568
+ ggml_tensor * llm_graph_context::build_rs(
1569
+ llm_graph_input_rs * inp,
1570
+ ggml_cgraph * gf,
1571
+ ggml_tensor * s,
1572
+ int32_t state_size,
1573
+ int32_t n_seqs,
1574
+ bool avoid_copies) const {
1575
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1576
+
1577
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1578
+ }
1579
+
1580
+ ggml_tensor * llm_graph_context::build_rs(
1581
+ llm_graph_input_mem_hybrid * inp,
1582
+ ggml_cgraph * gf,
1583
+ ggml_tensor * s,
1584
+ int32_t state_size,
1585
+ int32_t n_seqs,
1586
+ bool avoid_copies) const {
1587
+ const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
1588
+
1589
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
1476
1590
  }
1477
1591
 
1478
1592
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1479
- ggml_cgraph * gf,
1480
- ggml_tensor * state_copy,
1481
- ggml_tensor * state_mask,
1482
- const llama_ubatch & ubatch,
1593
+ llm_graph_input_rs * inp,
1594
+ ggml_cgraph * gf,
1595
+ const llama_ubatch & ubatch,
1483
1596
  int il) const {
1484
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1597
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1485
1598
 
1486
1599
  const auto token_shift_count = hparams.token_shift_count;
1487
1600
 
1488
1601
  const int64_t n_seqs = ubatch.n_seqs;
1489
1602
 
1490
- ggml_tensor * token_shift_all = kv_state->get_k_l(il);
1603
+ ggml_tensor * token_shift_all = kv_state->get_r_l(il);
1491
1604
 
1492
- ggml_tensor * token_shift = build_copy_mask_state(
1493
- gf, token_shift_all, state_copy, state_mask,
1494
- hparams.n_embd_k_s(), n_seqs);
1605
+ ggml_tensor * token_shift = build_rs(
1606
+ inp, gf, token_shift_all,
1607
+ hparams.n_embd_r(), n_seqs);
1495
1608
 
1496
1609
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
1497
1610
 
@@ -1502,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1502
1615
  ggml_tensor * token_shift,
1503
1616
  const llama_ubatch & ubatch,
1504
1617
  int il) const {
1505
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1618
+ const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
1506
1619
 
1507
1620
  const auto token_shift_count = hparams.token_shift_count;
1508
1621
  const auto n_embd = hparams.n_embd;
@@ -1514,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1514
1627
  return ggml_cpy(
1515
1628
  ctx0,
1516
1629
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
1517
- ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
1630
+ ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
1518
1631
  );
1519
1632
  }
1520
1633
 
@@ -1565,23 +1678,30 @@ void llm_graph_context::build_pooling(
1565
1678
  ggml_tensor * inp_cls = build_inp_cls();
1566
1679
  inp = ggml_get_rows(ctx0, inp, inp_cls);
1567
1680
 
1568
- if (cls != nullptr && cls_b != nullptr) {
1681
+ if (cls) {
1569
1682
  // classification head
1570
1683
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1571
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1684
+ cur = ggml_mul_mat(ctx0, cls, inp);
1685
+ if (cls_b) {
1686
+ cur = ggml_add(ctx0, cur, cls_b);
1687
+ }
1572
1688
  cur = ggml_tanh(ctx0, cur);
1573
1689
 
1574
1690
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1575
1691
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1576
1692
  if (cls_out) {
1577
- GGML_ASSERT(cls_out_b != nullptr);
1578
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1693
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1694
+ if (cls_out_b) {
1695
+ cur = ggml_add(ctx0, cur, cls_out_b);
1696
+ }
1579
1697
  }
1580
1698
  } else if (cls_out) {
1581
1699
  // Single layer classification head (direct projection)
1582
1700
  // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1583
- GGML_ASSERT(cls_out_b != nullptr);
1584
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1701
+ cur = ggml_mul_mat(ctx0, cls_out, inp);
1702
+ if (cls_out_b) {
1703
+ cur = ggml_add(ctx0, cur, cls_out_b);
1704
+ }
1585
1705
  } else {
1586
1706
  GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1587
1707
  }