whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-graph.cpp
@@ -3,7 +3,11 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-#include "llama-kv-cache.h"
+
+#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache-unified-iswa.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -83,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
+        mctx->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
 void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+    GGML_ASSERT(out_ids);
 
-        if (!out_ids) {
-            LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
-            int32_t * data = (int32_t *) out_ids->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+    int32_t * data = (int32_t *) out_ids->data;
 
-            if (n_outputs == n_tokens) {
-                for (int i = 0; i < n_tokens; ++i) {
-                    data[i] = i;
-                }
-            } else if (ubatch->output) {
-                int32_t n_outputs = 0;
-                for (int i = 0; i < n_tokens; ++i) {
-                    if (ubatch->output[i]) {
-                        data[n_outputs++] = i;
-                    }
-                }
-                // the graph needs to have been passed the correct number of outputs
-                GGML_ASSERT(n_outputs == n_outputs);
-            } else if (n_outputs == 1) {
-                // only keep last output
-                data[0] = n_tokens - 1;
-            } else {
-                GGML_ASSERT(n_outputs == 0);
-            }
+    if (n_outputs == n_tokens) {
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = i;
+        }
+
+        return;
+    }
+
+    GGML_ASSERT(ubatch->output);
+
+    int n_outputs = 0;
+
+    for (int i = 0; i < n_tokens; ++i) {
+        if (ubatch->output[i]) {
+            data[n_outputs++] = i;
         }
     }
 }
@@ -126,139 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
+        const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
         GGML_ASSERT(mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
 
         float * data = (float *) mean->data;
-        memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
-
-        std::vector<uint64_t> sum(n_tokens, 0);
+        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        std::vector<uint64_t> sums(n_seqs_unq, 0);
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
-            sum[seq_id] += ubatch->n_seq_tokens;
+                sums[seq_idx] += ubatch->n_seq_tokens;
+            }
         }
 
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
-            const uint64_t s = sum[i];
-            if (s > 0) {
-                div[i] = 1.0f/float(s);
+        std::vector<float> div(n_seqs_unq, 0.0f);
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            const uint64_t sum = sums[s];
+            if (sum > 0) {
+                div[s] = 1.0f/float(sum);
             }
         }
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
+                }
             }
         }
     }
 }
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    if (cparams.embeddings && (
-            cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-            cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
+    const int64_t n_tokens = ubatch->n_tokens;
+    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+    const int64_t n_seqs_unq = ubatch->n_seqs_unq;
 
+    if (cparams.embeddings && (
+            cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+            cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+            )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
-
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
+                data[seq_idx] = i;
             }
         }
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
 
         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
-
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        std::vector<int> last_pos(n_seqs_unq, -1);
+        std::vector<int> last_row(n_seqs_unq, -1);
 
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_pos pos = ubatch->pos[i];
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
+                const int32_t seq_idx = ubatch->seq_idx[seq_id];
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
+                if (pos >= last_pos[seq_idx]) {
+                    last_pos[seq_idx] = pos;
+                    last_row[seq_idx] = i;
                 }
             }
        }
 
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            if (last_row[s] >= 0) {
+                data[s] = last_row[s];
            }
        }
    }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_self->n;
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_copy(i);
-        }
-    }
-}
-
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_self->n;
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_mask(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -274,87 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-    if (kq_mask) {
-        if (cparams.causal_attn) {
-            const int64_t n_kv = ubatch->n_tokens;
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            const int64_t n_stride = ubatch->n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-            float * data = (float *) kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+    const int64_t n_kv = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
                         }
+                        break;
                     }
                 }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
@@ -362,53 +282,80 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    if (cross_kq_mask) {
-        const int64_t n_enc = cross_kq_mask->ne[0];
-        const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    const int64_t n_enc = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;
 
-        float * data = (float *) cross_kq_mask->data;
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_enc; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
-                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
                     }
-                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
                 }
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
+        }
 
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_enc; ++j) {
-                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-                }
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
             }
         }
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+void llm_graph_input_one::set_input(const llama_ubatch *) {
+    GGML_ASSERT(one && ggml_nelements(one) == 1);
+    float f_one = 1.0f;
+    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+}
+
 //
 // llm_graph_context
 //
@@ -448,16 +395,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu (params.backend_cpu),
     cvec (params.cvec),
     loras (params.loras),
-    memory (params.memory),
+    mctx (params.mctx),
     cross (params.cross),
     cb_func (params.cb),
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::n_pos_per_embd() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
@@ -617,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_gelu", il);
                 if (act_scales != NULL) {
@@ -631,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
                 }
             } break;
         case LLM_FFN_RELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_reglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_reglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_relu", il);
             } break;
@@ -645,16 +600,18 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_SWIGLU:
             {
-                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx0, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
-                cb(cur, "ffn_mul", il);
+                cur = ggml_swiglu(ctx0, cur);
+                cb(cur, "ffn_swiglu", il);
+            } break;
+        case LLM_FFN_GEGLU:
+            {
+                cur = ggml_geglu(ctx0, cur);
+                cb(cur, "ffn_geglu", il);
+            } break;
+        case LLM_FFN_REGLU:
+            {
+                cur = ggml_reglu(ctx0, cur);
+                cb(cur, "ffn_reglu", il);
             } break;
     }
 
@@ -766,9 +723,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
-        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
-        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
-        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        // repeat cur to [n_embd, n_expert_used, n_tokens]
+        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
         cur = ggml_mul(ctx0, repeated, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
@@ -786,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate_exps) {
+                cur = ggml_swiglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_swiglu", il);
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate_exps) {
+                cur = ggml_geglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_geglu", il);
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
@@ -799,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    if (gate_exps) {
-        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate_par", il);
-    }
-
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
@@ -888,11 +845,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -915,6 +872,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+    // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
+    //       but this would make the graph topology depend on the number of output tokens, which can interere with
+    //       features that require constant topology such as pipline parallelism
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
+    //if (n_outputs < n_tokens) {
+    //    return nullptr;
+    //}
+
     auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
 
     auto & cur = inp->out_ids;
@@ -932,7 +897,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 
     auto & cur = inp->mean;
 
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -945,41 +910,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 
     auto & cur = inp->cls;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
-
-    const auto n_kv = kv_self->n;
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
-
-    const auto n_kv = kv_self->n;
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1025,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
-    const auto n_kv = kv_self->get_n();
+    const auto n_kv = mctx_cur->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -1056,6 +987,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1231,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_self->get_n();
+        const auto n_kv = mctx_cur->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1268,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1300,36 +1258,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
-
-    {
-        const auto n_kv = kv_self->get_kv_base()->get_n();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_self->get_kv_swa()->get_n();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1345,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
+
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }
+
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     const bool is_swa = hparams.is_swa(il);
 
-    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
-    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+    // optionally store to KV cache
+    if (k_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
 
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    if (v_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv->get_k(ctx0, il);
-    ggml_tensor * v = kv->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1439,56 +1376,182 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-        int32_t n_state,
-        int32_t n_seqs) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
-    const auto kv_head = kv_self->head;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    return cur;
+}
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+    {
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
+        bool avoid_copies) const {
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
+
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
+
+    const auto n_rs = mctx_cur->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_self->k_l[il];
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
-            hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
+            hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1499,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_self->head;
+    const auto kv_head = mctx_cur->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
 
@@ -1562,20 +1625,32 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);
 
-                // classification head
-                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                GGML_ASSERT(cls != nullptr);
-                GGML_ASSERT(cls_b != nullptr);
-
-                cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
-                cur = ggml_tanh(ctx0, cur);
-
-                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                if (cls_out) {
-                    GGML_ASSERT(cls_out_b != nullptr);
-
-                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                if (cls) {
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    if (cls_b) {
+                        cur = ggml_add(ctx0, cur, cls_b);
+                    }
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (cls_out) {
+                        cur = ggml_mul_mat(ctx0, cls_out, cur);
+                        if (cls_out_b) {
+                            cur = ggml_add(ctx0, cur, cls_out_b);
+                        }
+                    }
+                } else if (cls_out) {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                    if (cls_out_b) {
+                        cur = ggml_add(ctx0, cur, cls_out_b);
+                    }
+                } else {
+                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
                 }
             } break;
         default: