whispercpp 1.3.2 → 1.3.4

This diff shows the changes between two publicly released versions of the package, as they appear in its public registry. It is provided for informational purposes only.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,6 +1,7 @@
1
1
  #pragma once
2
2
 
3
3
  #include "llama-arch.h"
4
+ #include "llama-batch.h"
4
5
  #include "llama-hparams.h"
5
6
  #include "llama-adapter.h"
6
7
 
@@ -14,13 +15,14 @@ struct ggml_cgraph;
14
15
  struct ggml_context;
15
16
  struct ggml_tensor;
16
17
 
17
- struct llama_ubatch;
18
18
  struct llama_cparams;
19
19
 
20
- class llama_memory_i;
21
- class llama_kv_cache_unified;
22
- class llama_kv_cache_unified_iswa;
23
- class llama_kv_cache_recurrent;
20
+ struct llama_memory_context_i;
21
+
22
+ class llama_kv_cache_context;
23
+ class llama_kv_cache_iswa_context;
24
+ class llama_memory_recurrent_context;
25
+ class llama_memory_hybrid_context;
24
26
 
25
27
  // certain models (typically multi-modal) can produce different types of graphs
26
28
  enum llm_graph_type {
@@ -35,6 +37,9 @@ enum llm_ffn_op_type {
35
37
  LLM_FFN_RELU,
36
38
  LLM_FFN_RELU_SQR,
37
39
  LLM_FFN_SWIGLU,
40
+ LLM_FFN_GEGLU,
41
+ LLM_FFN_REGLU,
42
+ LLM_FFN_SWIGLU_OAI_MOE,
38
43
  };
39
44
 
40
45
  enum llm_ffn_gate_type {
@@ -65,20 +70,38 @@ struct llama_cross {
65
70
  std::vector<std::set<llama_seq_id>> seq_ids_enc;
66
71
  };
67
72
 
73
+ struct llm_graph_params;
74
+
68
75
  //
69
76
  // llm_graph_input
70
77
  //
71
78
 
72
79
  class llm_graph_input_i {
73
80
  public:
81
+ llm_graph_input_i() {
82
+ const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
83
+ debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
84
+ }
85
+
74
86
  virtual ~llm_graph_input_i() = default;
75
87
 
76
88
  virtual void set_input(const llama_ubatch * ubatch) = 0;
89
+
90
+ // return true if the resulting input tensors using the provided graph parameters would be
91
+ // the same as the previous input tensors that we have currently stored in the object
92
+ virtual bool can_reuse(const llm_graph_params & params) {
93
+ // returning false here by default will prevent from reusing the graph if the check
94
+ // for the input type has not been implemented yet
95
+ GGML_UNUSED(params);
96
+ return false;
97
+ }
98
+ protected:
99
+ // env: LLAMA_GRAPH_INPUT_DEBUG
100
+ int debug = 0;
77
101
  };
78
102
 
79
103
  using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
80
104
 
81
-
82
105
  class llm_graph_input_embd : public llm_graph_input_i {
83
106
  public:
84
107
  llm_graph_input_embd() = default;
@@ -86,20 +109,24 @@ public:
86
109
 
87
110
  void set_input(const llama_ubatch * ubatch) override;
88
111
 
112
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd = nullptr;   // F32 [n_embd, n_batch]
 };

 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;

     void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * pos = nullptr; // I32 [n_batch]

-    const int64_t n_pos_per_embd = 1;
+    const uint32_t n_pos_per_embd = 1;
 };

 // temperature tuning, used by llama4
@@ -126,22 +153,23 @@ public:

     ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]

-    const llama_hparams & hparams;
+    const llama_hparams hparams;
 };

 class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]

-    const llama_hparams & hparams;
-    const llama_kv_cache_unified * kv_self;
+    const llama_hparams hparams;
+
+    const llama_kv_cache_context * mctx;
 };

 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -149,17 +177,19 @@ public:
     llm_graph_input_out_ids(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
+            uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
     virtual ~llm_graph_input_out_ids() = default;

     void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * out_ids; // I32 [n_outputs]

-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
+    const llama_hparams hparams;
+    const llama_cparams cparams;

-    const int32_t n_outputs;
+    const uint32_t n_outputs;
 };

 class llm_graph_input_mean : public llm_graph_input_i {
@@ -171,43 +201,37 @@ public:

     ggml_tensor * mean; // F32 [n_batch, n_batch]

-    const llama_cparams & cparams;
+    const llama_cparams cparams;
 };

 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * cls; // I32 [n_batch]

-    const llama_cparams & cparams;
+    const llama_cparams cparams;
+    const llm_arch arch;
 };

-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    virtual ~llm_graph_input_rs() = default;

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * s_copy; // I32 [kv_size]
-
-    const llama_kv_cache_recurrent * kv_self;
-};
-
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_s_mask() = default;
+    ggml_tensor * s_copy; // I32 [n_rs]

-    void set_input(const llama_ubatch * ubatch) override;
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]

-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_memory_recurrent_context * mctx;
 };

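As an aside, `s_copy_main` and `s_copy_extra` are plain views into `s_copy`, so a minimal sketch of how they could be carved out once per graph (assuming a `ggml_context * ctx0` and the counts `n_rs`/`n_seqs` are in scope) looks like:

    // allocate the full copy-index tensor, then take two 1d views of it;
    // ggml_view_1d takes the new element count and a byte offset into the source
    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);

    ggml_tensor * s_copy_main  = ggml_view_1d(ctx0, s_copy, n_seqs, 0);
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx0, s_copy, n_rs - n_seqs,
            n_seqs * ggml_element_size(s_copy));

Computing the views once and sharing them across layers avoids re-deriving the same split in every call to build_rs.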
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -235,64 +259,87 @@ public:

     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }

-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]

-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
+    const llama_hparams hparams;
+    const llama_cparams cparams;
 };

-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified * kv_self) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;

     void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]

-    const llama_kv_cache_unified * kv_self;
+    // note: these have to be copies because in order to be able to reuse a graph, its inputs
+    //       need to carry these parameters with them. otherwise, they can point to freed
+    //       llm_graph_params from a previous batch, causing stack-use-after-return
+    const llama_hparams hparams;
+    const llama_cparams cparams;
+
+    const llama_kv_cache_context * mctx;
 };

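The note above is worth pausing on. A minimal sketch of the lifetime bug that reference members would reintroduce, reduced to standalone, hypothetical types:

    struct params_t { bool causal_attn; };
    struct by_ref_t { const params_t & p; }; // dangles once params dies
    struct by_val_t { const params_t   p; }; // safe: carries its own copy

    by_val_t make_input() {
        params_t params = { /*causal_attn =*/ true }; // per-batch, stack-allocated
        // by_ref_t bad = { params };  // would dangle after this function returns
        by_val_t good  = { params };   // the copy survives the return
        return good;
    }

Since a reused graph outlives the llm_graph_params of the batch that built it, value members are the only safe option here.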
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa * kv_self) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;

     void set_input(const llama_ubatch * ubatch) override;

+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]

-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+
+    const llama_hparams hparams;
+    const llama_cparams cparams;

-    const llama_kv_cache_unified_iswa * kv_self;
+    const llama_kv_cache_iswa_context * mctx;
 };

 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -304,12 +351,34 @@ public:

     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

-    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

     const llama_cross * cross = nullptr;
 };

+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
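Given that the hybrid input simply owns one attention input and one recurrent-state input, its set_input presumably just forwards to both children. A plausible sketch (the real definition lives in llama-graph.cpp):

    void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
        inp_attn->set_input(ubatch);
        inp_rs  ->set_input(ubatch);
    }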
 //
 // llm_graph_result
 //
@@ -320,40 +389,110 @@ public:
 // along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
 // these are used by the llama_context to extract the relevant data, based on the compute parameters

-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
+// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;

-    virtual ggml_tensor * get_tokens()      = 0;
-    virtual ggml_tensor * get_logits()      = 0;
-    virtual ggml_tensor * get_embd()        = 0;
-    virtual ggml_tensor * get_embd_pooled() = 0;
+class llm_graph_result;

-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-};
+struct llm_graph_params {
+    llm_arch arch = LLM_ARCH_UNKNOWN;

-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
+    llama_hparams hparams;
+    llama_cparams cparams;

+    llama_ubatch ubatch; // note: intentionally make a copy

-class llm_graph_result : public llm_graph_result_i {
-public:
-    virtual ~llm_graph_result() = default;
+    llm_graph_type gtype;
+
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;

-    ggml_tensor * get_tokens()      override { return t_tokens; }
-    ggml_tensor * get_logits()      override { return t_logits; }
-    ggml_tensor * get_embd()        override { return t_embd; }
-    ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross * cross;
+
+    uint32_t n_outputs;
+
+    llm_graph_cb cb;
+
+    llm_graph_result * res;
+
+    // return true if the "other" params would result in a graph with the same topology as with the current params
+    // having the same topology allows us to reuse the graph in some cases
+    bool allow_reuse(const llm_graph_params & other) const {
+        // first check the ubatch
+        bool can_reuse_ubatch =
+            ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
+            ubatch.n_tokens     == other.ubatch.n_tokens     &&
+            ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
+            ubatch.n_seqs       == other.ubatch.n_seqs       &&
+            ubatch.n_seqs_unq   == other.ubatch.n_seqs_unq   &&
+            (
+                (!ubatch.token && !other.ubatch.token) ||
+                (!ubatch.embd  && !other.ubatch.embd)
+            );
+
+        // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same,
+        // because the set of attention streams would be different for different sequences
+        if (can_reuse_ubatch && ubatch.equal_seqs()) {
+            if (!ubatch.data) {
+                // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and
+                // therefore we cannot perform the sequence id check. normally this should never happen
+                can_reuse_ubatch = false;
+            } else {
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
+                }
+            }
+        }

-    void set_inputs(const llama_ubatch * ubatch) override {
-        for (auto & input : inputs) {
-            input->set_input(ubatch);
+        if (!can_reuse_ubatch) {
+            return false;
         }
-    }

-    llm_graph_input_i * add_input(llm_graph_input_ptr input) {
-        inputs.emplace_back(std::move(input));
-        return inputs.back().get();
+        return
+            cparams.embeddings  == other.cparams.embeddings  &&
+            cparams.causal_attn == other.cparams.causal_attn &&
+            arch      == other.arch  &&
+            gtype     == other.gtype &&
+            cvec      == other.cvec  &&
+            loras     == other.loras &&
+            cross     == other.cross &&
+            n_outputs == other.n_outputs;
     }
+};
+
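To make the reuse flow concrete, here is a minimal sketch of how a caller could combine allow_reuse() with the result object declared below when a new batch arrives (build_graph is a hypothetical helper; the real control flow lives in llama_context):

    ggml_cgraph * graph_for(llm_graph_result * res, const llm_graph_params & params) {
        if (res->can_reuse(params)) {
            // same topology as the previous batch: only the memory contexts of
            // the input tensors were updated, the graph nodes are kept as-is
            return res->get_gf();
        }
        res->reset();              // discard the old nodes and inputs
        build_graph(res, params);  // hypothetical: rebuild from scratch
        res->set_params(params);   // remember the params for the next reuse check
        return res->get_gf();
    }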
+class llm_graph_result {
+public:
+    llm_graph_result(int64_t max_nodes);
+
+    virtual ~llm_graph_result() = default;
+
+    ggml_tensor * get_tokens()      const { return t_tokens; }
+    ggml_tensor * get_logits()      const { return t_logits; }
+    ggml_tensor * get_embd()        const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
+
+    ggml_cgraph  * get_gf()  const { return gf; }
+    ggml_context * get_ctx() const { return ctx_compute.get(); }
+
+    int64_t get_max_nodes() const;
+
+    void reset();
+
+    void set_inputs(const llama_ubatch * ubatch);
+
+    // try to update the existing graph result using the new graph parameters in order to reuse it
+    // this can only be done if we determine that the resulting graph using the new graph parameters
+    // would be identical to the existing graph. in that case, we simply have to update the memory
+    // contexts of the input tensors of the graph and we can reuse it for another computation
+    // return true if the graph was updated and can be reused
+    bool can_reuse(const llm_graph_params & params);
+
+    llm_graph_input_i * add_input(llm_graph_input_ptr input);
+
+    void set_params(const llm_graph_params & params);

     // important graph nodes
     ggml_tensor * t_tokens = nullptr;
@@ -362,36 +501,34 @@ public:
     ggml_tensor * t_embd_pooled = nullptr;

     std::vector<llm_graph_input_ptr> inputs;
-};

-//
-// llm_graph_context
-//
+    ggml_context_ptr ctx_compute;

-// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;

-struct llm_graph_params {
-    ggml_context * ctx;
+    ggml_cgraph * gf;

-    const llm_arch arch;
+    int64_t max_nodes;

-    const llama_hparams & hparams;
-    const llama_cparams & cparams;
-    const llama_ubatch  & ubatch;
+private:
+    // keep a copy of the previous graph parameters
+    // we will use this to determine whether the graph can be reused by comparing them with the new parameters
+    // note: these are updated after constructing the new graph
+    llm_graph_params params;

-    ggml_backend_sched_t sched;
-    ggml_backend_t backend_cpu;
+    // env: LLAMA_GRAPH_RESULT_DEBUG
+    int debug = 0;
+};

-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;

-    int32_t n_outputs;
+//
+// llm_graph_context
+//

-    const llm_graph_cb & cb;
-};
+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

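The default argument for this hook is ggml_get_rows itself, which already matches the signature. A caller that needs different gather behavior can pass a lambda instead - for example (hypothetical), casting the states to F32 before gathering:

    llm_graph_get_rows_fn get_rows_f32 =
        [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            return ggml_get_rows(ctx, ggml_cast(ctx, states, GGML_TYPE_F32), ids);
        };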
 struct llm_graph_context {
     const llm_arch arch;
@@ -422,31 +559,31 @@ struct llm_graph_context {
     const float norm_eps;
     const float norm_rms_eps;

-    const int32_t n_tokens;
-    const int32_t n_outputs;
+    const int64_t n_tokens;
+    const int64_t n_outputs;
     const int32_t n_ctx_orig; // yarn

     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type rope_type;

-    ggml_context * ctx0 = nullptr;
-
     ggml_backend_sched_t sched;

     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?

-    const llama_adapter_cvec * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i * memory;
-    const llama_cross * cross;
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross * cross;

     const llm_graph_cb & cb_func;

-    std::unique_ptr<llm_graph_result> res;
+    llm_graph_result * res;

-    llm_graph_context(const llm_graph_params & params);
+    ggml_context * ctx0 = nullptr;
+    ggml_cgraph  * gf   = nullptr;

-    int64_t n_pos_per_embd() const;
+    llm_graph_context(const llm_graph_params & params);
+    virtual ~llm_graph_context() = default;

     void cb(ggml_tensor * cur, const char * name, int il) const;

@@ -492,6 +629,7 @@ struct llm_graph_context {
             llm_ffn_gate_type type_gate,
             int il) const;

+    // build MoE FFN without bias tensors
     ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
             ggml_tensor * gate_inp,
@@ -506,7 +644,29 @@ struct llm_graph_context {
             bool scale_w,
             float w_scale,
             llama_expert_gating_func_type gating_op,
-            int il) const;
+            int il,
+            ggml_tensor * probs_in = nullptr) const;
+
+    ggml_tensor * build_moe_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * gate_inp,
+            ggml_tensor * gate_inp_b,
+            ggml_tensor * up_exps,
+            ggml_tensor * up_exps_b,
+            ggml_tensor * gate_exps,
+            ggml_tensor * gate_exps_b,
+            ggml_tensor * down_exps,
+            ggml_tensor * down_exps_b,
+            ggml_tensor * exp_probs_b,
+            int64_t n_expert,
+            int64_t n_expert_used,
+            llm_ffn_op_type type_op,
+            bool norm_w,
+            bool scale_w,
+            float w_scale,
+            llama_expert_gating_func_type gating_op,
+            int il,
+            ggml_tensor * probs_in = nullptr) const;

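The two build_moe_ffn overloads differ only in the per-expert bias tensors (gate_inp_b, up_exps_b, gate_exps_b, down_exps_b). Presumably the bias-less convenience overload can simply forward to the full one - a sketch of the idea:

    // hypothetical forwarding; the actual definition lives in llama-graph.cpp
    return build_moe_ffn(cur,
            gate_inp,  /*gate_inp_b  =*/ nullptr,
            up_exps,   /*up_exps_b   =*/ nullptr,
            gate_exps, /*gate_exps_b =*/ nullptr,
            down_exps, /*down_exps_b =*/ nullptr,
            exp_probs_b, n_expert, n_expert_used,
            type_op, norm_w, scale_w, w_scale, gating_op, il, probs_in);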
     //
     // inputs
@@ -518,8 +678,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;

     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -531,56 +689,58 @@ struct llm_graph_context {
     //

     ggml_tensor * build_attn_mha(
-            ggml_cgraph * gf,
-            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-            ggml_tensor * kq_b,
-            ggml_tensor * kq_mask,
-            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale) const;
+            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;

     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;

     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;

-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;

     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;

-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -589,13 +749,13 @@ struct llm_graph_context {

     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,   // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur,   // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur,   // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks,   // [n_head_q]
             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;
@@ -604,32 +764,52 @@ struct llm_graph_context {
     // recurrent
     //

-    ggml_tensor * build_copy_mask_state(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            int32_t n_state,
-            int32_t n_seqs) const;
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_tensor * s,
+            ggml_tensor * state_copy_main,
+            ggml_tensor * state_copy_extra,
+            int32_t  state_size,
+            int32_t  n_seqs,
+            uint32_t n_rs,
+            uint32_t rs_head,
+            uint32_t rs_size,
+            int32_t  rs_zero,
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

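A rough sketch of the intended call pattern for these helpers inside a model's graph-build code (the tensor names and the state lookup are hypothetical):

    llm_graph_input_rs * inp = build_rs_inp(); // one shared input per graph

    for (int il = 0; il < n_layer; ++il) {
        // s: this layer's recurrent-state tensor from the memory context
        ggml_tensor * s = get_layer_state(il); // hypothetical accessor

        // gathers the state rows for the sequences in the current ubatch,
        // using the shared s_copy_main/s_copy_extra views computed once per graph
        ggml_tensor * cur = build_rs(inp, s, state_size, n_seqs);
    }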
     ggml_tensor * build_rwkv_token_shift_load(
-            ggml_cgraph * gf,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il) const;
+            llm_graph_input_rs * inp,
+            const llama_ubatch & ubatch,
+            int il) const;

     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,
             const llama_ubatch & ubatch,
             int il) const;
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

     //
     // pooling
     //

     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,