whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-kv-cells.h

@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -23,7 +24,7 @@ public:
 
         used.clear();
 
-        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             seq_pos[s].clear();
         }
     }
@@ -68,12 +69,6 @@ public:
     // the index of the last cell that is used + 1
     // return 0 if no cells are used
     uint32_t used_max_p1() const {
-#if 0
-        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
-        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
-        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
-#endif
-
         return used.empty() ? 0 : *used.rbegin() + 1;
     }
 
@@ -86,6 +81,9 @@ public:
         assert(isrc < pos.size());
         assert(idst < pos.size());
 
+        assert(pos[idst] == -1);
+        assert(pos[isrc] != -1);
+
         pos  [idst] = pos  [isrc];
         shift[idst] = shift[isrc];
         seq  [idst] = seq  [isrc];
@@ -144,6 +142,20 @@ public:
         }
     }
 
+    // clear a non-empty cell
+    void rm(uint32_t i) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        seq_pos_rm(i);
+        seq[i].reset();
+
+        pos[i] = -1;
+        shift[i] = 0;
+
+        used.erase(i);
+    }
+
     // note: call only if the cell has seq_id
     // return true if the cell becomes empty
     bool seq_rm(uint32_t i, llama_seq_id seq_id) {
@@ -153,10 +165,11 @@ public:
         assert(seq_id >= 0);
 
         seq[i].reset(seq_id);
-        seq_pos[seq_id].erase(pos[i]);
+        seq_pos_dec(seq_id, pos[i]);
 
         if (seq[i].none()) {
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -175,7 +188,7 @@ public:
             seq[i].reset();
 
             seq[i].set(seq_id);
-            seq_pos[seq_id].insert(pos[i]);
+            seq_pos_inc(seq_id, pos[i]);
 
             return false;
         }
@@ -185,6 +198,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -196,6 +210,15 @@ public:
         return false;
     }
 
+    // number of different sequences in the cell
+    int seq_count(uint32_t i) const {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        return seq[i].count();
+    }
+
+    // check if the cell contains seq_id
     bool seq_has(uint32_t i, llama_seq_id seq_id) const {
         assert(i < pos.size());
         assert(seq_id >= 0);
@@ -210,33 +233,51 @@ public:
         assert(!seq[i].test(seq_id));
 
         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
+    }
+
+    // return the sequence id of this cell
+    // note: call only for cells with exactly one sequence
+    llama_seq_id seq_get(uint32_t i) const {
+        assert(seq[i].count() == 1);
+
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (seq[i].test(s)) {
+                return s;
+            }
+        }
+
+        return -1;
     }
 
     // the minimum position of sequence seq_id currently present in any of the cells
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
         }
 
-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }
 
     // the maximum position of sequence seq_id currently present in any of the cells
     // return -1 if the sequence is not present
     llama_pos seq_pos_max(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
         }
 
-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }
 
     // note: call only if the cell is not empty
@@ -268,6 +309,7 @@ public:
     void pos_set(uint32_t i, llama_pos p) {
         assert(i < pos.size());
         assert(pos[i] == -1);
+        assert(seq[i].none());
 
         pos[i] = p;
 
@@ -286,21 +328,20 @@ public:
         pos[i]   += d;
         shift[i] += d;
 
-        seq_pos_add(i);
-
         has_shift = true;
 
         if (pos[i] < 0) {
-            seq_pos_rm(i);
-
             seq[i].reset();
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
             return true;
         }
 
+        seq_pos_add(i);
+
         return false;
     }
 
@@ -348,31 +389,50 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
-    std::vector<bits_t> seq;
+    std::vector<seq_set_t> seq;
 
-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+    //
+    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
+    //   - during performing a cache reuse via (rm + add)
+    //   - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
            }
         }
     }
 
     // add cell i
     void seq_pos_add(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
            }
         }
     }
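Note on the seq_pos change above: the per-sequence position tracking moves from std::set<llama_pos> to std::map<llama_pos, int>, i.e. a reference count per position, maintained by the new seq_pos_inc/seq_pos_dec helpers. The standalone sketch below (illustrative only, not part of the gem; the pos_tracker name and layout are made up) shows why the count matters: the same position can be registered more than once for one sequence, and the min/max queries must keep working until the last occurrence is removed.

// Minimal sketch of reference-counted position tracking, assuming the
// same map-based idea as the diff above but with simplified names.
#include <cassert>
#include <map>

struct pos_tracker {
    std::map<int, int> refs; // position -> occurrence count

    void inc(int p) { refs[p]++; }

    void dec(int p) {
        auto it = refs.find(p);
        assert(it != refs.end());
        if (--it->second == 0) {
            refs.erase(it); // drop the position only when the last occurrence goes away
        }
    }

    int pos_min() const { return refs.empty() ? -1 : refs.begin()->first; }
    int pos_max() const { return refs.empty() ? -1 : refs.rbegin()->first; }
};

int main() {
    pos_tracker t;
    t.inc(5);
    t.inc(5);                 // same position occurs twice (e.g. cache reuse via rm + add)
    t.dec(5);
    assert(t.pos_max() == 5); // a plain std::set would already have forgotten it
    t.dec(5);
    assert(t.pos_max() == -1);
    return 0;
}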
data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp (new file)

@@ -0,0 +1,246 @@
+#include "llama-memory-hybrid.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid
+//
+
+llama_memory_hybrid::llama_memory_hybrid(
+    const llama_model & model,
+    /* attn */
+    ggml_type type_k,
+    ggml_type type_v,
+    bool v_trans,
+    uint32_t kv_size,
+    uint32_t n_pad,
+    uint32_t n_swa,
+    llama_swa_type swa_type,
+    /* recurrent */
+    ggml_type type_r,
+    ggml_type type_s,
+    uint32_t rs_size,
+    /* common */
+    uint32_t n_seq_max,
+    bool offload,
+    /* layer filters */
+    layer_filter_cb && filter_attn,
+    layer_filter_cb && filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_unified(
+        model,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        kv_size,
+        n_seq_max,
+        n_pad,
+        n_swa,
+        swa_type
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max
+    )) {}
+
+llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();
+
+        // follow the recurrent pattern for creating the ubatch splits
+        std::vector<llama_ubatch> ubatches;
+
+        while (true) {
+            llama_ubatch ubatch;
+
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        // prepare the recurrent batches first
+        if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined context at this point?
+            LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+
+        // prepare the attention cache
+        auto heads_attn = mem_attn->prepare(ubatches);
+        if (heads_attn.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+
+        return std::make_unique<llama_memory_hybrid_context>(
+            this, std::move(heads_attn), std::move(ubatches));
+    } while(false);
+
+    return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_context>(this);
+}
+
+llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
+}
+
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_context::llama_memory_hybrid_context(
+    llama_memory_hybrid * mem,
+    llama_context * lctx,
+    bool optimize) :
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_context::llama_memory_hybrid_context(
+    llama_memory_hybrid * mem,
+    std::vector<uint32_t> heads_attn,
+    std::vector<llama_ubatch> ubatches) :
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+bool llama_memory_hybrid_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    ctx_attn->next();
+    ctx_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    bool res = true;
+
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
+
+    return res;
+}
+
+llama_memory_status llama_memory_hybrid_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
+}
+
+const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
+}
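A small worked example of the range combination used by seq_pos_min/seq_pos_max above (illustrative only; the values are hypothetical): a position is only usable if both the attention cache and the recurrent cache still hold it, so the hybrid minimum is the larger of the two minima and the hybrid maximum is the smaller of the two maxima.

// Standalone sketch of the hybrid min/max combination; not part of the package.
#include <algorithm>
#include <cassert>

int main() {
    // hypothetical per-cache position ranges for one sequence
    int attn_min = 0,  attn_max = 100;
    int recr_min = 64, recr_max = 100;

    int hybrid_min = std::max(attn_min, recr_min); // 64: both caches cover [64, ...]
    int hybrid_max = std::min(attn_max, recr_max); // 100: both caches cover [..., 100]

    assert(hybrid_min == 64 && hybrid_max == 100);
    return 0;
}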
data/ext/sources/examples/talk-llama/llama-memory-hybrid.h (new file)

@@ -0,0 +1,138 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// support models where each layer may be either attention-based or recurrent
+
+class llama_memory_hybrid : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_hybrid(
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && filter_attn = nullptr,
+        layer_filter_cb && filter_recr = nullptr);
+
+    ~llama_memory_hybrid() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+        llama_batch_allocr & balloc,
+        uint32_t n_ubatch,
+        bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    //
+    // llama_memory_hybrid specific API
+    //
+
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_context : public llama_memory_context_i {
+public:
+    // init failure
+    explicit llama_memory_hybrid_context(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
+
+    // init update
+    explicit llama_memory_hybrid_context(
+        llama_memory_hybrid * mem,
+        llama_context * lctx,
+        bool optimize);
+
+    // init success
+    llama_memory_hybrid_context(
+        llama_memory_hybrid * mem,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches);
+
+    ~llama_memory_hybrid_context() = default;
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_context
+    //
+
+    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;
+
+private:
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;
+
+    const llama_memory_status status;
+};
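For orientation, a hedged caller-side sketch (not part of this package) of how a context produced by llama_memory_hybrid::init_batch() would be consumed, following the next()/apply()/get_ubatch() protocol declared above; process_batch and its parameters are hypothetical names introduced only for the example.

// Assumes the llama-memory-hybrid.h header shown above is available.
#include "llama-memory-hybrid.h"

void process_batch(llama_memory_hybrid & mem, llama_batch_allocr & balloc, uint32_t n_ubatch) {
    // split the batch into ubatches and reserve space in both sub-caches
    llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, /*embd_all=*/false);

    if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return; // failed to prepare the ubatch splits
    }

    do {
        if (!mctx->apply()) { // commit the current ubatch to the attention and recurrent caches
            break;
        }

        const llama_ubatch & ub = mctx->get_ubatch();
        // ... run the compute graph for 'ub' here ...
        (void) ub;
    } while (mctx->next()); // advance until all ubatches are consumed
}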