whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-graph.h

@@ -17,10 +17,12 @@ struct ggml_tensor;
  struct llama_ubatch;
  struct llama_cparams;

- class llama_memory_i;
- class llama_kv_cache_unified;
- class llama_kv_cache_unified_iswa;
- class llama_kv_cache_recurrent;
+ struct llama_memory_context_i;
+
+ class llama_kv_cache_unified_context;
+ class llama_kv_cache_unified_iswa_context;
+ class llama_memory_recurrent_context;
+ class llama_memory_hybrid_context;

  // certain models (typically multi-modal) can produce different types of graphs
  enum llm_graph_type {
@@ -35,6 +37,8 @@ enum llm_ffn_op_type {
      LLM_FFN_RELU,
      LLM_FFN_RELU_SQR,
      LLM_FFN_SWIGLU,
+     LLM_FFN_GEGLU,
+     LLM_FFN_REGLU,
  };

  enum llm_ffn_gate_type {
@@ -92,14 +96,14 @@ public:

  class llm_graph_input_pos : public llm_graph_input_i {
  public:
-     llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+     llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
      virtual ~llm_graph_input_pos() = default;

      void set_input(const llama_ubatch * ubatch) override;

      ggml_tensor * pos = nullptr; // I32 [n_batch]

-     const int64_t n_pos_per_embd = 1;
+     const uint32_t n_pos_per_embd = 1;
  };

  // temperature tuning, used by llama4
@@ -133,7 +137,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
  public:
      llm_graph_input_pos_bucket_kv(
              const llama_hparams & hparams,
-             const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+             const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
      virtual ~llm_graph_input_pos_bucket_kv() = default;

      void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +145,8 @@ public:
      ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

      const llama_hparams & hparams;
-     const llama_kv_cache_unified * kv_self;
+
+     const llama_kv_cache_unified_context * mctx;
  };

  class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -186,28 +191,16 @@ public:
      const llama_cparams & cparams;
  };

- class llm_graph_input_s_copy : public llm_graph_input_i {
+ class llm_graph_input_rs : public llm_graph_input_i {
  public:
-     llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-     virtual ~llm_graph_input_s_copy() = default;
+     llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+     virtual ~llm_graph_input_rs() = default;

      void set_input(const llama_ubatch * ubatch) override;

      ggml_tensor * s_copy; // I32 [kv_size]

-     const llama_kv_cache_recurrent * kv_self;
- };
-
- class llm_graph_input_s_mask : public llm_graph_input_i {
- public:
-     llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-     virtual ~llm_graph_input_s_mask() = default;
-
-     void set_input(const llama_ubatch * ubatch) override;
-
-     ggml_tensor * s_mask; // F32 [1, n_kv]
-
-     const llama_kv_cache_recurrent * kv_self;
+     const llama_memory_recurrent_context * mctx;
  };

  class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +240,10 @@ public:
      llm_graph_input_attn_kv_unified(
              const llama_hparams & hparams,
              const llama_cparams & cparams,
-             const llama_kv_cache_unified * kv_self) :
+             const llama_kv_cache_unified_context * mctx) :
          hparams(hparams),
          cparams(cparams),
-         kv_self(kv_self) {
+         mctx(mctx) {
      }
      ~llm_graph_input_attn_kv_unified() = default;

@@ -264,7 +257,7 @@ public:
      const llama_hparams & hparams;
      const llama_cparams & cparams;

-     const llama_kv_cache_unified * kv_self;
+     const llama_kv_cache_unified_context * mctx;
  };

  class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +265,10 @@ public:
      llm_graph_input_attn_kv_unified_iswa(
              const llama_hparams & hparams,
              const llama_cparams & cparams,
-             const llama_kv_cache_unified_iswa * kv_self) :
+             const llama_kv_cache_unified_iswa_context * mctx) :
          hparams(hparams),
          cparams(cparams),
-         kv_self(kv_self) {
+         mctx(mctx) {
      }
      ~llm_graph_input_attn_kv_unified_iswa() = default;

@@ -292,7 +285,7 @@ public:
      const llama_hparams & hparams;
      const llama_cparams & cparams;

-     const llama_kv_cache_unified_iswa * kv_self;
+     const llama_kv_cache_unified_iswa_context * mctx;
  };

  class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -310,6 +303,44 @@ public:
      const llama_cross * cross = nullptr;
  };

+ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+ public:
+     llm_graph_input_mem_hybrid(
+             const llama_hparams & hparams,
+             const llama_cparams & cparams,
+             const llama_memory_hybrid_context * mctx) :
+         hparams(hparams),
+         cparams(cparams),
+         mctx(mctx) {
+     }
+     virtual ~llm_graph_input_mem_hybrid() = default;
+
+     void set_input(const llama_ubatch * ubatch) override;
+
+     ggml_tensor * s_copy; // I32 [kv_size]
+
+     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
+     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+
+     const llama_hparams & hparams;
+     const llama_cparams & cparams;
+
+     const llama_memory_hybrid_context * mctx;
+ };
+
+ // TODO: remove this when ggml_scale_add is implemented
+ class llm_graph_input_one : public llm_graph_input_i {
+ public:
+     llm_graph_input_one() {}
+     virtual ~llm_graph_input_one() = default;
+
+     void set_input(const llama_ubatch *) override;
+
+     ggml_tensor * one = nullptr; // F32
+ };
+
  //
  // llm_graph_result
  //
@@ -383,12 +414,12 @@ struct llm_graph_params {
      ggml_backend_sched_t sched;
      ggml_backend_t backend_cpu;

-     const llama_adapter_cvec  * cvec;
-     const llama_adapter_loras * loras;
-     const llama_memory_i      * memory;
-     const llama_cross         * cross;
+     const llama_adapter_cvec     * cvec;
+     const llama_adapter_loras    * loras;
+     const llama_memory_context_i * mctx;
+     const llama_cross            * cross;

-     int32_t n_outputs;
+     uint32_t n_outputs;

      const llm_graph_cb & cb;
  };
@@ -422,8 +453,8 @@ struct llm_graph_context {
      const float norm_eps;
      const float norm_rms_eps;

-     const int32_t n_tokens;
-     const int32_t n_outputs;
+     const int64_t n_tokens;
+     const int64_t n_outputs;
      const int32_t n_ctx_orig; // yarn

      const enum llama_pooling_type pooling_type;
@@ -435,18 +466,17 @@ struct llm_graph_context {

      ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

-     const llama_adapter_cvec  * cvec;
-     const llama_adapter_loras * loras;
-     const llama_memory_i      * memory;
-     const llama_cross         * cross;
+     const llama_adapter_cvec     * cvec;
+     const llama_adapter_loras    * loras;
+     const llama_memory_context_i * mctx;
+     const llama_cross            * cross;

      const llm_graph_cb & cb_func;

      std::unique_ptr<llm_graph_result> res;

      llm_graph_context(const llm_graph_params & params);
-
-     int64_t n_pos_per_embd() const;
+     virtual ~llm_graph_context() = default;

      void cb(ggml_tensor * cur, const char * name, int il) const;

@@ -518,14 +548,14 @@ struct llm_graph_context {
      ggml_tensor * build_inp_out_ids() const;
      ggml_tensor * build_inp_mean() const;
      ggml_tensor * build_inp_cls() const;
-     ggml_tensor * build_inp_s_copy() const;
-     ggml_tensor * build_inp_s_mask() const;

      ggml_tensor * build_inp_cross_embd() const;
      ggml_tensor * build_inp_pos_bucket_enc() const;
      ggml_tensor * build_inp_pos_bucket_dec() const;
      ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

+     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
      //
      // attention
      //
@@ -572,14 +602,15 @@ struct llm_graph_context {

      llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;

+     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
      ggml_tensor * build_attn(
              llm_graph_input_attn_kv_unified_iswa * inp,
              ggml_cgraph * gf,
              ggml_tensor * wo,
              ggml_tensor * wo_b,
              ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
              ggml_tensor * kq_b,
              ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
              float kq_scale,
@@ -600,23 +631,62 @@ struct llm_graph_context {
              float kq_scale,
              int il) const;

+     ggml_tensor * build_attn(
+             llm_graph_input_mem_hybrid * inp,
+             ggml_cgraph * gf,
+             ggml_tensor * wo,
+             ggml_tensor * wo_b,
+             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+             ggml_tensor * kq_b,
+             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+             float kq_scale,
+             int il) const;
      //
      // recurrent
      //

-     ggml_tensor * build_copy_mask_state(
-             ggml_cgraph * gf,
-             ggml_tensor * s,
-             ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
-             int32_t n_state,
-             int32_t n_seqs) const;
+     // TODO: avoid notion of "kv"
+     // TODO: move this implementation to llama_memory_recurrent.
+     // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+     // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+     // `llama_memory_recurrent`
+     ggml_tensor * build_rs(
+             ggml_cgraph * gf,
+             ggml_tensor * s,
+             ggml_tensor * state_copy,
+             int32_t state_size,
+             int32_t n_seqs,
+             uint32_t n_kv,
+             uint32_t kv_head,
+             uint32_t kv_size,
+             int32_t rs_zero,
+             bool avoid_copies = false) const;
+
+     llm_graph_input_rs * build_rs_inp() const;
+
+     ggml_tensor * build_rs(
+             llm_graph_input_rs * inp,
+             ggml_cgraph * gf,
+             ggml_tensor * s,
+             int32_t state_size,
+             int32_t n_seqs,
+             bool avoid_copies = false) const;
+
+     ggml_tensor * build_rs(
+             llm_graph_input_mem_hybrid * inp,
+             ggml_cgraph * gf,
+             ggml_tensor * s,
+             int32_t state_size,
+             int32_t n_seqs,
+             bool avoid_copies = false) const;

      ggml_tensor * build_rwkv_token_shift_load(
-             ggml_cgraph * gf,
-             ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
-             const llama_ubatch & ubatch,
+             llm_graph_input_rs * inp,
+             ggml_cgraph * gf,
+             const llama_ubatch & ubatch,
              int il) const;

      ggml_tensor * build_rwkv_token_shift_store(
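The llama-graph.h changes above replace the paired `s_copy`/`s_mask` recurrent-state inputs (and `build_copy_mask_state`) with a single `llm_graph_input_rs` input that drives the new `build_rs` overloads. A minimal caller-side sketch of the new flow, assuming the headers in this diff; the function name and the `conv_states` tensor are illustrative placeholders, not code from this release:

#include "llama-graph.h" // builds only inside the talk-llama/llama.cpp source tree

// sketch: one recurrent-state input object now both selects and copies states
static ggml_tensor * build_recurrent_state_sketch(
        const llm_graph_context & g,
        ggml_cgraph * gf,
        ggml_tensor * conv_states, // hypothetical per-layer state buffer
        int32_t state_size,
        int32_t n_seqs) {
    llm_graph_input_rs * inp = g.build_rs_inp();                  // replaces build_inp_s_copy() + build_inp_s_mask()
    return g.build_rs(inp, gf, conv_states, state_size, n_seqs);  // replaces build_copy_mask_state(...)
}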
data/ext/sources/examples/talk-llama/llama-hparams.cpp

@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
      return n_embd_head_v * n_head_kv;
  }

- uint32_t llama_hparams::n_embd_k_s() const {
+ uint32_t llama_hparams::n_embd_r() const {
      if (wkv_head_size != 0) {
          // for RWKV models
          return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
      return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
  }

- uint32_t llama_hparams::n_embd_v_s() const {
+ uint32_t llama_hparams::n_embd_s() const {
      if (wkv_head_size != 0) {
          // corresponds to RWKV's wkv_states size
          return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
      return ssm_d_state * ssm_d_inner;
  }

+ bool llama_hparams::is_recurrent(uint32_t il) const {
+     return recurrent_layer_arr[il];
+ }
+
+ uint32_t llama_hparams::n_pos_per_embd() const {
+     return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+ }
+
  bool llama_hparams::is_swa(uint32_t il) const {
      if (il < n_layer) {
          return swa_layers[il];
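The accessor renames above (`n_embd_k_s`/`n_embd_v_s` to `n_embd_r`/`n_embd_s`) reflect that these are per-layer recurrent-state sizes (rolling and recurrent), not key/value sizes, and `n_pos_per_embd()` moves into `llama_hparams` (M-RoPE uses 4 position components per embedding, other RoPE types use 1). A worked example of the two state-size formulas with assumed Mamba-style values; the numbers are illustrative, not taken from this diff:

#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical hyperparameters for a Mamba-style (non-RWKV) model
    const uint32_t ssm_d_conv  = 4;
    const uint32_t ssm_d_inner = 1536;
    const uint32_t ssm_d_state = 16;

    // same formulas as llama_hparams::n_embd_r() / n_embd_s() above (wkv_head_size == 0)
    const uint32_t n_embd_r = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; // rolling (conv) state
    const uint32_t n_embd_s = ssm_d_state * ssm_d_inner;                           // recurrent (ssm) state

    printf("n_embd_r = %u, n_embd_s = %u per layer\n", n_embd_r, n_embd_s); // 4608, 24576
    return 0;
}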
data/ext/sources/examples/talk-llama/llama-hparams.h

@@ -115,6 +115,9 @@ struct llama_hparams {
      uint32_t ssm_d_state = 0;
      uint32_t ssm_dt_rank = 0;

+     // for hybrid state space models
+     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
      bool ssm_dt_b_c_rms = false;

      float f_clamp_kqv = 0.0f;
@@ -131,12 +134,21 @@ struct llama_hparams {
      bool attn_soft_cap = false;
      bool use_kq_norm = true;

+     // for Classifiers
+     uint32_t n_cls_out = 1;
+
      // llama4
      uint32_t n_moe_layer_step = 0;
      uint32_t n_no_rope_layer_step = 4;
      uint32_t n_attn_temp_floor_scale = 8192;
      float f_attn_temp_scale = 0.1;

+     // gemma3n altup
+     uint32_t n_altup = 4; // altup_num_inputs
+     uint32_t i_altup_act = 0; // altup_active_idx
+     uint32_t laurel_rank = 64;
+     uint32_t n_embd_altup = 256;
+
      // needed by encoder-decoder models (e.g. T5, FLAN-T5)
      // ref: https://github.com/ggerganov/llama.cpp/pull/8141
      llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@@ -178,10 +190,15 @@ struct llama_hparams {

      // dimension of the rolling state embeddings
      // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-     uint32_t n_embd_k_s() const;
+     uint32_t n_embd_r() const;

      // dimension of the recurrent state embeddings
-     uint32_t n_embd_v_s() const;
+     uint32_t n_embd_s() const;
+
+     // whether or not the given layer is recurrent (for hybrid models)
+     bool is_recurrent(uint32_t il) const;
+
+     uint32_t n_pos_per_embd() const;

      bool is_swa(uint32_t il) const;
  };
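The new `recurrent_layer_arr` flag array and `is_recurrent(il)` predicate let hybrid models (see the new llama-memory-hybrid.h/.cpp in the file list above) decide per layer whether memory lives in recurrent state buffers or in the attention KV cache. A self-contained sketch of that routing, with a stand-in struct in place of the real `llama_hparams`:

#include <array>
#include <cstdint>
#include <cstdio>

// stand-in mirroring the fields added above; values are illustrative
struct hparams_sketch {
    static constexpr size_t max_layers = 512; // plays the role of LLAMA_MAX_LAYERS

    uint32_t n_layer = 4;
    std::array<bool, max_layers> recurrent_layer_arr{};

    bool is_recurrent(uint32_t il) const { return recurrent_layer_arr[il]; }
};

int main() {
    hparams_sketch hp;
    hp.recurrent_layer_arr = {true, true, false, true}; // hybrid: attention only on layer 2

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        // a hybrid memory module would reserve n_embd_r() + n_embd_s() state cells
        // for recurrent layers and KV cells for attention layers
        printf("layer %u -> %s\n", il, hp.is_recurrent(il) ? "recurrent state" : "attention KV");
    }
    return 0;
}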
data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp (new file)

@@ -0,0 +1,279 @@
+ #include "llama-kv-cache-unified-iswa.h"
+
+ #include "llama-impl.h"
+ #include "llama-batch.h"
+ #include "llama-model.h"
+
+ #include <algorithm>
+ #include <cassert>
+
+ //
+ // llama_kv_cache_unified_iswa
+ //
+
+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+         const llama_model & model,
+         ggml_type type_k,
+         ggml_type type_v,
+         bool v_trans,
+         bool offload,
+         bool swa_full,
+         uint32_t kv_size,
+         uint32_t n_seq_max,
+         uint32_t n_ubatch,
+         uint32_t n_pad) : hparams(model.hparams) {
+     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+     llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
+
+     const uint32_t size_base = kv_size;
+
+     uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+
+     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
+     if (swa_full) {
+         LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+         size_swa = size_base;
+     }
+
+     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+     kv_base = std::make_unique<llama_kv_cache_unified>(
+             model, std::move(filter_base), type_k, type_v,
+             v_trans, offload, size_base, n_seq_max, n_pad,
+             0, LLAMA_SWA_TYPE_NONE);
+
+     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+     kv_swa = std::make_unique<llama_kv_cache_unified>(
+             model, std::move(filter_swa), type_k, type_v,
+             v_trans, offload, size_swa, n_seq_max, n_pad,
+             hparams.n_swa, hparams.swa_type);
+ }
+
+ void llama_kv_cache_unified_iswa::clear(bool data) {
+     kv_base->clear(data);
+     kv_swa ->clear(data);
+ }
+
+ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+     bool res = true;
+
+     res = res & kv_base->seq_rm(seq_id, p0, p1);
+     res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+     return res;
+ }
+
+ void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+     kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+     kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+     kv_base->seq_keep(seq_id);
+     kv_swa ->seq_keep(seq_id);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+     kv_base->seq_add(seq_id, p0, p1, shift);
+     kv_swa ->seq_add(seq_id, p0, p1, shift);
+ }
+
+ void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+     kv_base->seq_div(seq_id, p0, p1, d);
+     kv_swa ->seq_div(seq_id, p0, p1, d);
+ }
+
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+     // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+     return kv_swa->seq_pos_min(seq_id);
+ }
+
+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+     return kv_swa->seq_pos_max(seq_id);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+     GGML_UNUSED(embd_all);
+
+     // first try simple split
+     do {
+         balloc.split_reset();
+
+         std::vector<llama_ubatch> ubatches;
+         while (true) {
+             auto ubatch = balloc.split_simple(n_ubatch);
+
+             if (ubatch.n_tokens == 0) {
+                 break;
+             }
+
+             ubatches.push_back(std::move(ubatch)); // NOLINT
+         }
+
+         auto heads_base = kv_base->prepare(ubatches);
+         if (heads_base.empty()) {
+             break;
+         }
+
+         auto heads_swa = kv_swa->prepare(ubatches);
+         if (heads_swa.empty()) {
+             break;
+         }
+
+         assert(heads_base.size() == heads_swa.size());
+
+         return std::make_unique<llama_kv_cache_unified_iswa_context>(
+                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+     } while (false);
+
+     // if it fails, try equal split
+     do {
+         balloc.split_reset();
+
+         std::vector<llama_ubatch> ubatches;
+         while (true) {
+             auto ubatch = balloc.split_equal(n_ubatch);
+
+             if (ubatch.n_tokens == 0) {
+                 break;
+             }
+
+             ubatches.push_back(std::move(ubatch)); // NOLINT
+         }
+
+         auto heads_base = kv_base->prepare(ubatches);
+         if (heads_base.empty()) {
+             break;
+         }
+
+         auto heads_swa = kv_swa->prepare(ubatches);
+         if (heads_swa.empty()) {
+             break;
+         }
+
+         assert(heads_base.size() == heads_swa.size());
+
+         return std::make_unique<llama_kv_cache_unified_iswa_context>(
+                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+     } while (false);
+
+     // TODO: if we fail again, we should attempt different splitting strategies
+     // but to do that properly, we first have to refactor the batches to be more flexible
+
+     return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+     return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+ }
+
+ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+     return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+ }
+
+ bool llama_kv_cache_unified_iswa::get_can_shift() const {
+     return kv_base->get_size() == kv_swa->get_size();
+ }
+
+ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+     kv_base->state_write(io, seq_id);
+     kv_swa ->state_write(io, seq_id);
+ }
+
+ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+     kv_base->state_read(io, seq_id);
+     kv_swa ->state_read(io, seq_id);
+ }
+
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+     return kv_base.get();
+ }
+
+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+     return kv_swa.get();
+ }
+
+ //
+ // llama_kv_cache_unified_iswa_context
+ //
+
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+         llama_kv_cache_unified_iswa * kv) :
+     ctx_base(kv->get_base()->init_full()),
+     ctx_swa (kv->get_swa ()->init_full()),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+         llama_kv_cache_unified_iswa * kv,
+         llama_context * lctx,
+         bool optimize) :
+     ctx_base(kv->get_base()->init_update(lctx, optimize)),
+     ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
+         llama_kv_cache_unified_iswa * kv,
+         std::vector<uint32_t> heads_base,
+         std::vector<uint32_t> heads_swa,
+         std::vector<llama_ubatch> ubatches) :
+     ubatches(std::move(ubatches)),
+     // note: here we copy the ubatches. not sure if this is ideal
+     ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+     ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+ }
+
+ llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
+
+ bool llama_kv_cache_unified_iswa_context::next() {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     ctx_base->next();
+     ctx_swa ->next();
+
+     if (++i_next >= ubatches.size()) {
+         return false;
+     }
+
+     return true;
+ }
+
+ bool llama_kv_cache_unified_iswa_context::apply() {
+     assert(!llama_memory_status_is_fail(status));
+
+     bool res = true;
+
+     res = res & ctx_base->apply();
+     res = res & ctx_swa ->apply();
+
+     return res;
+ }
+
+ llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+     return status;
+ }
+
+ const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return ubatches[i_next];
+ }
+
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+ }
+
+ const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
+     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+     return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+ }
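The context object returned by `init_batch` above packages the planned ubatch sequence for both sub-caches; the consumer (in upstream llama.cpp, `llama_context`) drives it with an apply/next loop, and the iSWA context forwards each step to its base and SWA children in lockstep. A stand-in sketch of that protocol; `mctx_sketch` and the loop body are illustrative, not the real interface:

#include <cstdio>

// minimal stand-in for the apply()/next() protocol shown above
struct mctx_sketch {
    int n_ubatches = 3; // pretend init_batch split the batch into 3 ubatches
    int i_next     = 0;

    bool apply() { return true; }                  // position both sub-caches for ubatch i_next
    bool next()  { return ++i_next < n_ubatches; } // advance; false when exhausted
};

int main() {
    mctx_sketch mctx; // as if returned by init_batch(...) with SUCCESS status
    do {
        if (!mctx.apply()) {
            break; // failed to reserve cells in one of the sub-caches
        }
        printf("decode ubatch %d\n", mctx.i_next); // run the compute graph here
    } while (mctx.next());
    return 0;
}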