whispercpp 1.3.2 → 1.3.4

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/examples/talk-llama/llama-graph.cpp
@@ -3,7 +3,11 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
+
 #include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -24,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+
+    return res;
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -46,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;
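The `can_reuse` hooks added in the two hunks above (a third, for `llm_graph_input_out_ids`, appears further below) let the runtime keep a previously built compute graph when a new ubatch has a compatible shape, instead of rebuilding it. A minimal self-contained sketch of the pattern, using mock stand-ins rather than the real llama-graph.h types (all `mock_*` names are hypothetical, not from the diff):

    #include <memory>
    #include <vector>

    // Mock stand-ins for llm_graph_params / llm_graph_input_i; only the
    // shape of the interface matters here.
    struct mock_graph_params { int n_tokens; int n_outputs; };

    struct mock_graph_input {
        virtual ~mock_graph_input() = default;
        // mirrors llm_graph_input_*::can_reuse(const llm_graph_params &)
        virtual bool can_reuse(const mock_graph_params & params) const = 0;
    };

    // analogue of llm_graph_input_pos::can_reuse above: reusable iff the
    // cached position tensor has one entry per token of the incoming ubatch
    struct mock_input_pos : mock_graph_input {
        int pos_ne0 = 0;  // stands in for pos->ne[0]
        bool can_reuse(const mock_graph_params & params) const override {
            return pos_ne0 == params.n_tokens;
        }
    };

    // the cached graph may be reused only if every input votes yes; a single
    // mismatch (e.g. a different n_tokens) forces a rebuild
    static bool can_reuse_graph(
            const std::vector<std::unique_ptr<mock_graph_input>> & inputs,
            const mock_graph_params & params) {
        for (const auto & inp : inputs) {
            if (!inp->can_reuse(params)) {
                return false;
            }
        }
        return true;
    }

    int main() {
        std::vector<std::unique_ptr<mock_graph_input>> inputs;
        auto pos = std::make_unique<mock_input_pos>();
        pos->pos_ne0 = 32;
        inputs.push_back(std::move(pos));

        const mock_graph_params params = {32, 1};
        return can_reuse_graph(inputs, params) ? 0 : 1;  // reusable here
    }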
@@ -67,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     int32_t * data = (int32_t *) pos_bucket->data;
 
@@ -83,182 +104,149 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
83
104
 
84
105
  void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
85
106
  if (pos_bucket) {
86
- kv_self->set_input_pos_bucket(pos_bucket, ubatch);
107
+ mctx->set_input_pos_bucket(pos_bucket, ubatch);
87
108
  }
88
109
  }
89
110
 
90
111
  void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
91
- if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
92
- //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
112
+ GGML_ASSERT(out_ids);
93
113
 
94
- if (!out_ids) {
95
- LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
96
- } else {
97
- const int64_t n_tokens = ubatch->n_tokens;
114
+ const int64_t n_tokens = ubatch->n_tokens;
98
115
 
99
- GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
100
- int32_t * data = (int32_t *) out_ids->data;
116
+ GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
117
+ int32_t * data = (int32_t *) out_ids->data;
101
118
 
102
- if (n_outputs == n_tokens) {
103
- for (int i = 0; i < n_tokens; ++i) {
104
- data[i] = i;
105
- }
106
- } else if (ubatch->output) {
107
- int32_t n_outputs = 0;
108
- for (int i = 0; i < n_tokens; ++i) {
109
- if (ubatch->output[i]) {
110
- data[n_outputs++] = i;
111
- }
112
- }
113
- // the graph needs to have been passed the correct number of outputs
114
- GGML_ASSERT(n_outputs == n_outputs);
115
- } else if (n_outputs == 1) {
116
- // only keep last output
117
- data[0] = n_tokens - 1;
118
- } else {
119
- GGML_ASSERT(n_outputs == 0);
120
- }
119
+ if (n_outputs == n_tokens) {
120
+ for (int i = 0; i < n_tokens; ++i) {
121
+ data[i] = i;
122
+ }
123
+
124
+ return;
125
+ }
126
+
127
+ GGML_ASSERT(ubatch->output);
128
+
129
+ int n_outputs = 0;
130
+
131
+ for (int i = 0; i < n_tokens; ++i) {
132
+ if (ubatch->output[i]) {
133
+ data[n_outputs++] = i;
121
134
  }
122
135
  }
123
136
  }
124
137
 
138
+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
139
+ bool res = true;
140
+
141
+ res &= n_outputs == params.n_outputs;
142
+
143
+ return res;
144
+ }
145
+
125
146
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
126
147
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
127
148
  const int64_t n_tokens = ubatch->n_tokens;
128
149
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
129
- const int64_t n_seqs = ubatch->n_seqs;
150
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
130
151
 
131
152
  GGML_ASSERT(mean);
132
153
  GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
133
154
 
134
155
  float * data = (float *) mean->data;
135
- memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
136
-
137
- std::vector<uint64_t> sum(n_tokens, 0);
138
-
139
- for (int s = 0; s < n_seqs; ++s) {
140
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
156
+ memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
141
157
 
142
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
143
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
158
+ std::vector<uint64_t> sums(n_seqs_unq, 0);
159
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
160
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
161
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
162
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
144
163
 
145
- sum[seq_id] += ubatch->n_seq_tokens;
146
- }
147
-
148
- std::vector<float> div(n_tokens, 0.0f);
149
- for (int i = 0; i < n_tokens; ++i) {
150
- const uint64_t s = sum[i];
151
- if (s > 0) {
152
- div[i] = 1.0f/float(s);
164
+ sums[seq_idx] += ubatch->n_seq_tokens;
153
165
  }
154
166
  }
155
167
 
156
- for (int s = 0; s < n_seqs; ++s) {
157
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
158
-
159
- for (int i = 0; i < n_seq_tokens; ++i) {
160
- data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
168
+ std::vector<float> div(n_seqs_unq, 0.0f);
169
+ for (int s = 0; s < n_seqs_unq; ++s) {
170
+ const uint64_t sum = sums[s];
171
+ if (sum > 0) {
172
+ div[s] = 1.0f/float(sum);
161
173
  }
162
174
  }
163
- }
164
- }
165
-
166
- void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
167
- if (cparams.embeddings && (
168
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
169
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
170
- const int64_t n_tokens = ubatch->n_tokens;
171
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
172
- const int64_t n_seqs = ubatch->n_seqs;
173
175
 
174
- GGML_ASSERT(cls);
175
- GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
176
-
177
- uint32_t * data = (uint32_t *) cls->data;
178
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
179
-
180
- for (int s = 0; s < n_seqs; ++s) {
181
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
182
-
183
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
184
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
176
+ for (int i = 0; i < n_tokens; i += n_seq_tokens) {
177
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
178
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
179
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
185
180
 
186
- for (int i = 0; i < n_seq_tokens; ++i) {
187
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
188
-
189
- if (pos == 0) {
190
- data[seq_id] = s*n_seq_tokens + i;
181
+ for (int j = 0; j < n_seq_tokens; ++j) {
182
+ data[seq_idx*n_tokens + i + j] = div[seq_idx];
191
183
  }
192
184
  }
193
185
  }
194
186
  }
187
+ }
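The reworked mean input is an n_tokens × n_seqs_unq matrix: row s carries 1/len(s) in the columns of the tokens belonging to sequence s, so multiplying it against the token embeddings yields per-sequence means. A small self-contained illustration (the token-to-sequence assignment is made up for the example):

    #include <cstdio>
    #include <vector>

    // hypothetical: 2 unique sequences over 5 tokens; tokens 0-2 belong to seq 0, 3-4 to seq 1
    int main() {
        const int n_tokens = 5, n_seqs_unq = 2;
        const int seq_idx_of_token[n_tokens] = {0, 0, 0, 1, 1};

        std::vector<int> len(n_seqs_unq, 0);
        for (int i = 0; i < n_tokens; ++i) len[seq_idx_of_token[i]]++;

        // row s of the pooling matrix holds 1/len(s) at the columns of its tokens
        std::vector<float> mean(n_seqs_unq * n_tokens, 0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            const int s = seq_idx_of_token[i];
            mean[s*n_tokens + i] = 1.0f/len[s];
        }

        for (int s = 0; s < n_seqs_unq; ++s) {
            for (int i = 0; i < n_tokens; ++i) printf("%.2f ", mean[s*n_tokens + i]);
            printf("\n"); // prints: 0.33 0.33 0.33 0.00 0.00 / 0.00 0.00 0.00 0.50 0.50
        }
    }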
195
188
 
196
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
197
- const int64_t n_tokens = ubatch->n_tokens;
198
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
199
- const int64_t n_seqs = ubatch->n_seqs;
189
+ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
190
+ const int64_t n_tokens = ubatch->n_tokens;
191
+ const int64_t n_seqs_unq = ubatch->n_seqs_unq;
200
192
 
193
+ if (cparams.embeddings && (
194
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197
+ )) {
201
198
  GGML_ASSERT(cls);
202
199
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
203
200
 
204
201
  uint32_t * data = (uint32_t *) cls->data;
205
- memset(cls->data, 0, n_tokens * ggml_element_size(cls));
202
+ memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
206
203
 
207
- std::vector<int> last_pos(n_tokens, -1);
208
- std::vector<int> last_row(n_tokens, -1);
204
+ std::vector<int> target_pos(n_seqs_unq, -1);
205
+ std::vector<int> target_row(n_seqs_unq, -1);
209
206
 
210
- for (int s = 0; s < n_seqs; ++s) {
211
- const llama_seq_id seq_id = ubatch->seq_id[s][0];
207
+ const bool last = (
208
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
209
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
210
+ );
212
211
 
213
- // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
214
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
215
-
216
- for (int i = 0; i < n_seq_tokens; ++i) {
217
- const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
218
-
219
- if (pos >= last_pos[seq_id]) {
220
- last_pos[seq_id] = pos;
221
- last_row[seq_id] = s*n_seq_tokens + i;
212
+ for (int i = 0; i < n_tokens; ++i) {
213
+ const llama_pos pos = ubatch->pos[i];
214
+
215
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
216
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
217
+ const int32_t seq_idx = ubatch->seq_idx[seq_id];
218
+
219
+ if (
220
+ (target_pos[seq_idx] == -1) ||
221
+ ( last && pos >= target_pos[seq_idx]) ||
222
+ (!last && pos < target_pos[seq_idx])
223
+ ) {
224
+ target_pos[seq_idx] = pos;
225
+ target_row[seq_idx] = i;
222
226
  }
223
227
  }
224
228
  }
225
229
 
226
- for (int i = 0; i < n_tokens; ++i) {
227
- if (last_row[i] >= 0) {
228
- data[i] = last_row[i];
230
+ for (int s = 0; s < n_seqs_unq; ++s) {
231
+ if (target_row[s] >= 0) {
232
+ data[s] = target_row[s];
229
233
  }
230
234
  }
231
235
  }
232
236
  }
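The merged CLS/RANK/LAST path above keeps one representative row per unique sequence: the token at the smallest position for CLS/RANK, or the largest for LAST (and for Qwen3 ranking). A standalone sketch of that selection rule, with hypothetical flat inputs:

    #include <vector>

    // hypothetical: pick one representative row per unique sequence (first or last position)
    static std::vector<int> pick_rows(const std::vector<int> & pos,
                                      const std::vector<int> & seq_idx_of_token,
                                      int n_seqs_unq, bool last) {
        std::vector<int> target_pos(n_seqs_unq, -1);
        std::vector<int> target_row(n_seqs_unq, -1);
        for (int i = 0; i < (int) pos.size(); ++i) {
            const int s = seq_idx_of_token[i];
            if (target_pos[s] == -1 ||
                ( last && pos[i] >= target_pos[s]) ||   // keep the latest position
                (!last && pos[i] <  target_pos[s])) {   // keep the earliest position
                target_pos[s] = pos[i];
                target_row[s] = i;
            }
        }
        return target_row; // -1 entries mean the sequence had no token in this ubatch
    }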
233
237
 
234
- void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
238
+ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
235
239
  GGML_UNUSED(ubatch);
236
240
 
237
- const int64_t n_kv = kv_self->n;
241
+ const int64_t n_rs = mctx->get_n_rs();
238
242
 
239
243
  if (s_copy) {
240
244
  GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
241
245
  int32_t * data = (int32_t *) s_copy->data;
242
246
 
243
247
  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
244
- for (uint32_t i = 0; i < n_kv; ++i) {
245
- data[i] = kv_self->s_copy(i);
246
- }
247
- }
248
- }
249
-
250
- void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
251
- GGML_UNUSED(ubatch);
252
-
253
- const int64_t n_kv = kv_self->n;
254
-
255
- if (s_mask) {
256
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
257
- float * data = (float *) s_mask->data;
258
-
259
- // clear unused states
260
- for (int i = 0; i < n_kv; ++i) {
261
- data[i] = kv_self->s_mask(i);
248
+ for (uint32_t i = 0; i < n_rs; ++i) {
249
+ data[i] = mctx->s_copy(i);
262
250
  }
263
251
  }
264
252
  }
@@ -273,142 +261,270 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
273
261
  }
274
262
  }
275
263
 
264
+ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
265
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
266
+ const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
267
+ (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
268
+ (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
269
+ (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
270
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
271
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
272
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
273
+
274
+ LLAMA_LOG_DEBUG(" ");
275
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
276
+ LLAMA_LOG_DEBUG("%2d", j);
277
+ }
278
+ LLAMA_LOG_DEBUG("\n");
279
+
280
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
281
+ LLAMA_LOG_DEBUG(" %2d ", i);
282
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
283
+ float val = data[i * n_kv + j];
284
+ if (val == -INFINITY) {
285
+ LLAMA_LOG_DEBUG(" ∞");
286
+ } else {
287
+ LLAMA_LOG_DEBUG(" 0");
288
+ }
289
+ }
290
+ LLAMA_LOG_DEBUG("\n");
291
+ }
292
+ }
293
+
276
294
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
277
- if (kq_mask) {
278
- if (cparams.causal_attn) {
279
- const int64_t n_kv = ubatch->n_tokens;
280
- const int64_t n_tokens = ubatch->n_tokens;
281
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
282
- const int64_t n_seqs = ubatch->n_seqs;
283
-
284
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
285
- float * data = (float *) kq_mask->data;
286
-
287
- for (int h = 0; h < 1; ++h) {
288
- for (int s1 = 0; s1 < n_seqs; ++s1) {
289
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
290
-
291
- for (int j = 0; j < n_seq_tokens; ++j) {
292
- const int32_t tj = s1*n_seq_tokens + j;
293
-
294
- for (int s0 = 0; s0 < n_seqs; ++s0) {
295
- for (int i = 0; i < n_seq_tokens; ++i) {
296
- const int32_t ti = s0*n_seq_tokens + i;
297
- float f = -INFINITY;
298
-
299
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
300
- if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
301
- if (hparams.use_alibi) {
302
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
303
- } else {
304
- f = 0.0f;
305
- }
306
- break;
307
- }
308
- }
309
-
310
- data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
311
- }
312
- }
295
+ const int64_t n_kv = ubatch->n_tokens;
296
+ const int64_t n_tokens = ubatch->n_tokens;
297
+
298
+ GGML_ASSERT(kq_mask);
299
+ GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
300
+
301
+ float * data = (float *) kq_mask->data;
302
+
303
+ // [TAG_NO_CACHE_ISWA]
304
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
305
+
306
+ for (int h = 0; h < 1; ++h) {
307
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
308
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
309
+
310
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
311
+ float f = -INFINITY;
312
+
313
+ for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
314
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
315
+
316
+ if (s0 != s1) {
317
+ continue; // skip different sequences
313
318
  }
314
- }
315
- }
316
- } else {
317
- const int64_t n_tokens = ubatch->n_tokens;
318
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
319
- const int64_t n_seqs = ubatch->n_seqs;
320
- const int64_t n_stride = ubatch->n_tokens;
321
-
322
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
323
-
324
- float * data = (float *) kq_mask->data;
325
-
326
- for (int h = 0; h < 1; ++h) {
327
- for (int s1 = 0; s1 < n_seqs; ++s1) {
328
- const llama_seq_id seq_id = ubatch->seq_id[s1][0];
329
-
330
- for (int j = 0; j < n_seq_tokens; ++j) {
331
- const int32_t tj = s1*n_seq_tokens + j;
332
-
333
- for (int s0 = 0; s0 < n_seqs; ++s0) {
334
- for (int i = 0; i < n_seq_tokens; ++i) {
335
- const int32_t ti = s0*n_seq_tokens + i;
336
- float f = -INFINITY;
337
-
338
- for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
339
- if (ubatch->seq_id[s0][s] == seq_id) {
340
- if (hparams.use_alibi) {
341
- f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
342
- } else {
343
- f = 0.0f;
344
- }
345
- break;
346
- }
347
- }
348
-
349
- data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
350
- }
351
- }
352
-
353
- for (int i = n_tokens; i < n_stride; ++i) {
354
- data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
355
- }
319
+
320
+ if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
321
+ continue; // skip future tokens for causal attention
322
+ }
323
+
324
+ // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
325
+ //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
326
+ // continue; // skip masked tokens for SWA
327
+ //}
328
+
329
+ // TODO: reimplement this like in llama_kv_cache_unified
330
+ if (hparams.use_alibi) {
331
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
332
+ } else {
333
+ f = 0.0f;
356
334
  }
357
335
  }
336
+ data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
358
337
  }
359
338
  }
360
339
  }
340
+ if (debug) {
341
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
342
+ }
361
343
  }
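The collapsed loop above applies one rule per (query, key) pair: keys in a different sequence are masked, causal attention additionally masks future positions, and ALiBi replaces the zero with a distance penalty. A self-contained sketch of the same rule (build_mask and its arguments are illustrative, not library API):

    #include <cmath>
    #include <cstdlib>
    #include <vector>

    // query row i1 may attend key i0 iff same sequence and (non-causal or pos[i0] <= pos[i1])
    static std::vector<float> build_mask(const std::vector<int> & pos,
                                         const std::vector<int> & seq,
                                         bool causal, bool alibi) {
        const int n = (int) pos.size();
        std::vector<float> mask(n*n, -INFINITY);
        for (int i1 = 0; i1 < n; ++i1) {
            for (int i0 = 0; i0 < n; ++i0) {
                if (seq[i0] != seq[i1]) continue;           // different sequences never attend
                if (causal && pos[i0] > pos[i1]) continue;  // no peeking at future tokens
                mask[i1*n + i0] = alibi ? -std::abs(pos[i0] - pos[i1]) : 0.0f;
            }
        }
        return mask;
    }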
362
344
 
363
- void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
364
- if (self_kq_mask) {
365
- kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
366
- }
345
+ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
346
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
347
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
348
+
349
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
367
350
  }
368
351
 
369
- void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
370
- if (self_kq_mask) {
371
- kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
372
- }
352
+ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
353
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
373
354
 
374
- if (self_kq_mask_swa) {
375
- kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
376
- }
355
+ this->mctx = mctx;
356
+
357
+ bool res = true;
358
+
359
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
360
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
361
+
362
+ res &= self_kq_mask->ne[0] == mctx->get_n_kv();
363
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
364
+
365
+ return res;
366
+ }
367
+
368
+ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
369
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
370
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
371
+
372
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
373
+
374
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
375
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
376
+
377
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
378
+ }
379
+
380
+ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
381
+ const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
382
+
383
+ this->mctx = mctx;
384
+
385
+ bool res = true;
386
+
387
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
388
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
389
+
390
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
391
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
392
+
393
+ res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
394
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
395
+
396
+ res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
397
+ res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
398
+
399
+ return res;
377
400
  }
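Both can_reuse() implementations simply AND together shape checks, and the mask's second dimension is compared after padding, so a graph stays reusable across ubatch sizes that round to the same padded value. A tiny sketch of that rounding, assuming GGML_PAD(x, n) rounds x up to the next multiple of n (the pad size 32 below is illustrative, not the library constant):

    #include <cassert>
    #include <cstdint>

    constexpr int64_t pad_up(int64_t x, int64_t n) { return ((x + n - 1)/n)*n; }

    int main() {
        assert(pad_up(1,  32) == 32);   // a 1-token ubatch still gets a 32-row mask
        assert(pad_up(32, 32) == 32);
        assert(pad_up(33, 32) == 64);   // shapes change only at padding boundaries,
                                        // so many ubatch sizes can share one graph
    }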
378
401
 
379
402
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
380
- if (cross_kq_mask) {
381
- const int64_t n_enc = cross_kq_mask->ne[0];
382
- const int64_t n_tokens = ubatch->n_tokens;
403
+ GGML_ASSERT(cross_kq_mask);
383
404
 
384
- GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
385
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
405
+ const int64_t n_enc = cross_kq_mask->ne[0];
406
+ const int64_t n_tokens = ubatch->n_tokens;
386
407
 
387
- float * data = (float *) cross_kq_mask->data;
408
+ GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
409
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
388
410
 
389
- for (int h = 0; h < 1; ++h) {
390
- for (int j = 0; j < n_tokens; ++j) {
391
- for (int i = 0; i < n_enc; ++i) {
392
- float f = -INFINITY;
393
- for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
394
- const llama_seq_id seq_id = ubatch->seq_id[j][s];
395
- if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
396
- f = 0.0f;
397
- }
411
+ float * data = (float *) cross_kq_mask->data;
412
+
413
+ for (int h = 0; h < 1; ++h) {
414
+ for (int i = 0; i < n_tokens; ++i) {
415
+ for (int j = 0; j < n_enc; ++j) {
416
+ float f = -INFINITY;
417
+
418
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
419
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
420
+
421
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
422
+ f = 0.0f;
398
423
  }
399
- data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
400
424
  }
425
+
426
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
401
427
  }
428
+ }
402
429
 
403
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
404
- for (int j = 0; j < n_enc; ++j) {
405
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
406
- }
430
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
431
+ for (int j = 0; j < n_enc; ++j) {
432
+ data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
407
433
  }
408
434
  }
409
435
  }
410
436
  }
411
437
 
438
+ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
439
+ inp_attn->set_input(ubatch);
440
+ inp_rs->set_input(ubatch);
441
+ }
442
+
443
+ //
444
+ // llm_graph_result
445
+ //
446
+
447
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
448
+ reset();
449
+
450
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
451
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
452
+ }
453
+
454
+ int64_t llm_graph_result::get_max_nodes() const {
455
+ return max_nodes;
456
+ }
457
+
458
+ void llm_graph_result::reset() {
459
+ t_tokens = nullptr;
460
+ t_logits = nullptr;
461
+ t_embd = nullptr;
462
+ t_embd_pooled = nullptr;
463
+
464
+ params = {};
465
+
466
+ inputs.clear();
467
+
468
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
469
+
470
+ ggml_init_params params = {
471
+ /*.mem_size =*/ buf_compute_meta.size(),
472
+ /*.mem_buffer =*/ buf_compute_meta.data(),
473
+ /*.no_alloc =*/ true,
474
+ };
475
+
476
+ ctx_compute.reset(ggml_init(params));
477
+
478
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
479
+ }
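reset() above sizes the meta buffer for tensor and graph bookkeeping only and creates the context with no_alloc set, so no tensor data is allocated there. A minimal sketch of the same pattern against the public ggml API (make_meta_graph is illustrative; the real code keeps the context alive in ctx_compute rather than leaking it):

    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    // illustrative: a no_alloc ggml context sized to hold only metadata for max_nodes nodes
    static ggml_cgraph * make_meta_graph(std::vector<uint8_t> & buf, int64_t max_nodes) {
        buf.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

        ggml_init_params params = {
            /*.mem_size   =*/ buf.size(),
            /*.mem_buffer =*/ buf.data(),
            /*.no_alloc   =*/ true, // metadata only; tensor data lives elsewhere
        };

        ggml_context * ctx = ggml_init(params); // note: the caller must keep and free ctx
        return ggml_new_graph_custom(ctx, max_nodes, false);
    }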
480
+
481
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
482
+ for (auto & input : inputs) {
483
+ input->set_input(ubatch);
484
+ }
485
+ }
486
+
487
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
488
+ if (!this->params.allow_reuse(params)) {
489
+ if (debug > 1) {
490
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
491
+ }
492
+
493
+ return false;
494
+ }
495
+
496
+ if (debug > 1) {
497
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
498
+ }
499
+
500
+ bool res = true;
501
+
502
+ for (auto & input : inputs) {
503
+ const bool cur = input->can_reuse(params);
504
+
505
+ if (debug > 1) {
506
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
507
+ }
508
+
509
+ res = res && cur;
510
+ }
511
+
512
+ if (debug > 0) {
513
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
514
+ }
515
+
516
+ return res;
517
+ }
518
+
519
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
520
+ inputs.emplace_back(std::move(input));
521
+ return inputs.back().get();
522
+ }
523
+
524
+ void llm_graph_result::set_params(const llm_graph_params & params) {
525
+ this->params = params;
526
+ }
527
+
412
528
  //
413
529
  // llm_graph_context
414
530
  //
@@ -443,21 +559,19 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
443
559
  n_ctx_orig (cparams.n_ctx_orig_yarn),
444
560
  pooling_type (cparams.pooling_type),
445
561
  rope_type (hparams.rope_type),
446
- ctx0 (params.ctx),
447
562
  sched (params.sched),
448
563
  backend_cpu (params.backend_cpu),
449
564
  cvec (params.cvec),
450
565
  loras (params.loras),
451
- memory (params.memory),
566
+ mctx (params.mctx),
452
567
  cross (params.cross),
453
568
  cb_func (params.cb),
454
- res (std::make_unique<llm_graph_result>()) {
569
+ res (params.res),
570
+ ctx0 (res->get_ctx()),
571
+ gf (res->get_gf()) {
572
+ res->set_params(params);
455
573
  }
456
574
 
457
- int64_t llm_graph_context::n_pos_per_embd() const {
458
- return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
459
- }
460
-
461
575
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
462
576
  if (cb_func) {
463
577
  cb_func(ubatch, cur, name, il);
@@ -617,12 +731,20 @@ ggml_tensor * llm_graph_context::build_ffn(
617
731
 
618
732
  switch (type_op) {
619
733
  case LLM_FFN_SILU:
620
- {
734
+ if (gate && type_gate == LLM_FFN_PAR) {
735
+ cur = ggml_swiglu_split(ctx0, cur, tmp);
736
+ cb(cur, "ffn_swiglu", il);
737
+ type_gate = LLM_FFN_SEQ;
738
+ } else {
621
739
  cur = ggml_silu(ctx0, cur);
622
740
  cb(cur, "ffn_silu", il);
623
741
  } break;
624
742
  case LLM_FFN_GELU:
625
- {
743
+ if (gate && type_gate == LLM_FFN_PAR) {
744
+ cur = ggml_geglu_split(ctx0, cur, tmp);
745
+ cb(cur, "ffn_geglu", il);
746
+ type_gate = LLM_FFN_SEQ;
747
+ } else {
626
748
  cur = ggml_gelu(ctx0, cur);
627
749
  cb(cur, "ffn_gelu", il);
628
750
  if (act_scales != NULL) {
@@ -631,7 +753,11 @@ ggml_tensor * llm_graph_context::build_ffn(
631
753
  }
632
754
  } break;
633
755
  case LLM_FFN_RELU:
634
- {
756
+ if (gate && type_gate == LLM_FFN_PAR) {
757
+ cur = ggml_reglu_split(ctx0, cur, tmp);
758
+ cb(cur, "ffn_reglu", il);
759
+ type_gate = LLM_FFN_SEQ;
760
+ } else {
635
761
  cur = ggml_relu(ctx0, cur);
636
762
  cb(cur, "ffn_relu", il);
637
763
  } break;
@@ -645,17 +771,21 @@ ggml_tensor * llm_graph_context::build_ffn(
645
771
  } break;
646
772
  case LLM_FFN_SWIGLU:
647
773
  {
648
- // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
649
- int64_t split_point = cur->ne[0] / 2;
650
- ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
651
- ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
652
-
653
- x0 = ggml_silu(ctx0, x0);
654
- cb(cur, "ffn_silu", il);
655
-
656
- cur = ggml_mul(ctx0, x0, x1);
657
- cb(cur, "ffn_mul", il);
774
+ cur = ggml_swiglu(ctx0, cur);
775
+ cb(cur, "ffn_swiglu", il);
658
776
  } break;
777
+ case LLM_FFN_GEGLU:
778
+ {
779
+ cur = ggml_geglu(ctx0, cur);
780
+ cb(cur, "ffn_geglu", il);
781
+ } break;
782
+ case LLM_FFN_REGLU:
783
+ {
784
+ cur = ggml_reglu(ctx0, cur);
785
+ cb(cur, "ffn_reglu", il);
786
+ } break;
787
+ default:
788
+ GGML_ABORT("fatal error");
659
789
  }
660
790
 
661
791
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -665,8 +795,8 @@ ggml_tensor * llm_graph_context::build_ffn(
665
795
 
666
796
  if (down) {
667
797
  cur = build_lora_mm(down, cur);
668
- if (arch == LLM_ARCH_GLM4) {
669
- // GLM4 seems to have numerical issues with half-precision accumulators
798
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
799
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
670
800
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
671
801
  }
672
802
  }
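The SILU/GELU/RELU cases above now fuse the parallel gate into single GLU ops (ggml_swiglu_split and friends) instead of a separate activation followed by ggml_mul, and LLM_FFN_SWIGLU collapses the manual split-views into one ggml_swiglu call. A scalar sketch of the underlying math, assuming the usual GLU-variants convention of activating one half and multiplying by the other:

    #include <cmath>

    // scalar sketch: silu(x) = x * sigmoid(x)
    static float silu(float x) { return x / (1.0f + std::exp(-x)); }

    // ggml_swiglu_split(gate, up): act(gate) * up, elementwise
    static float swiglu_split(float gate, float up) { return silu(gate) * up; }

    // ggml_swiglu(x): x is split in half along dim 0; act(first half) * second half
    // (assumption: this matches the half-split convention the fused op uses)
    static void swiglu(const float * x, float * out, int n /* even */) {
        for (int i = 0; i < n/2; ++i) {
            out[i] = silu(x[i]) * x[i + n/2];
        }
    }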
@@ -701,13 +831,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
701
831
  bool scale_w,
702
832
  float w_scale,
703
833
  llama_expert_gating_func_type gating_op,
704
- int il) const {
834
+ int il,
835
+ ggml_tensor * probs_in) const {
836
+ return build_moe_ffn(
837
+ cur,
838
+ gate_inp, /* gate_inp_b */ nullptr,
839
+ up_exps, /* up_exps_b */ nullptr,
840
+ gate_exps, /* gate_exps_b */ nullptr,
841
+ down_exps, /* down_exps_b */ nullptr,
842
+ exp_probs_b,
843
+ n_expert,
844
+ n_expert_used,
845
+ type_op,
846
+ norm_w,
847
+ scale_w,
848
+ w_scale,
849
+ gating_op,
850
+ il,
851
+ probs_in
852
+ );
853
+ }
854
+
855
+ ggml_tensor * llm_graph_context::build_moe_ffn(
856
+ ggml_tensor * cur,
857
+ ggml_tensor * gate_inp,
858
+ ggml_tensor * gate_inp_b,
859
+ ggml_tensor * up_exps,
860
+ ggml_tensor * up_exps_b,
861
+ ggml_tensor * gate_exps,
862
+ ggml_tensor * gate_exps_b,
863
+ ggml_tensor * down_exps,
864
+ ggml_tensor * down_exps_b,
865
+ ggml_tensor * exp_probs_b,
866
+ int64_t n_expert,
867
+ int64_t n_expert_used,
868
+ llm_ffn_op_type type_op,
869
+ bool norm_w,
870
+ bool scale_w,
871
+ float w_scale,
872
+ llama_expert_gating_func_type gating_op,
873
+ int il,
874
+ ggml_tensor * probs_in) const {
705
875
  const int64_t n_embd = cur->ne[0];
706
876
  const int64_t n_tokens = cur->ne[1];
707
877
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
708
878
 
709
- ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
710
- cb(logits, "ffn_moe_logits", il);
879
+ ggml_tensor * logits = nullptr;
880
+
881
+ if (probs_in == nullptr) {
882
+ logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
883
+ cb(logits, "ffn_moe_logits", il);
884
+ } else {
885
+ logits = probs_in;
886
+ }
887
+
888
+ if (gate_inp_b) {
889
+ logits = ggml_add(ctx0, logits, gate_inp_b);
890
+ cb(logits, "ffn_moe_logits_biased", il);
891
+ }
711
892
 
712
893
  ggml_tensor * probs = nullptr;
713
894
  switch (gating_op) {
@@ -719,6 +900,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
719
900
  {
720
901
  probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
721
902
  } break;
903
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
904
+ {
905
+ probs = logits; // [n_expert, n_tokens]
906
+ } break;
722
907
  default:
723
908
  GGML_ABORT("fatal error");
724
909
  }
@@ -738,15 +923,36 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
738
923
  selection_probs = logits;
739
924
  }
740
925
 
926
+ if (arch == LLM_ARCH_GROVEMOE) {
927
+ selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
928
+ cb(selection_probs, "ffn_moe_probs_biased", il);
929
+ }
930
+
741
931
  // select experts
742
932
  ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
743
933
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
744
934
  cb(selected_experts, "ffn_moe_topk", il);
745
935
 
746
- ggml_tensor * weights = ggml_get_rows(ctx0,
747
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
936
+ if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
937
+ // TODO: Use scalar div instead when/if implemented
938
+ ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
939
+ selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
940
+ probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
941
+ } else {
942
+ probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
943
+ }
944
+
945
+ ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
748
946
  cb(weights, "ffn_moe_weights", il);
749
947
 
948
+
949
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
950
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
951
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
952
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
953
+ cb(weights, "ffn_moe_weights_softmax", il);
954
+ }
955
+
750
956
  if (norm_w) {
751
957
  weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
752
958
 
@@ -763,12 +969,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
763
969
  cb(weights, "ffn_moe_weights_scaled", il);
764
970
  }
765
971
 
972
+ // call early so that topk-moe can be used
973
+ ggml_build_forward_expand(gf, weights);
974
+
766
975
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
767
976
 
768
977
  if (weight_before_ffn) {
769
- // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
770
- ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
771
- repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
978
+ // repeat cur to [n_embd, n_expert_used, n_tokens]
979
+ ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
772
980
  cur = ggml_mul(ctx0, repeated, weights);
773
981
  cb(cur, "ffn_moe_weighted", il);
774
982
  }
@@ -776,6 +984,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
776
984
  ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
777
985
  cb(up, "ffn_moe_up", il);
778
986
 
987
+ if (up_exps_b) {
988
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
989
+ cb(up, "ffn_moe_up_biased", il);
990
+ }
991
+
779
992
  ggml_tensor * experts = nullptr;
780
993
  if (gate_exps) {
781
994
  cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -784,48 +997,83 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
784
997
  cur = up;
785
998
  }
786
999
 
1000
+ if (gate_exps_b) {
1001
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1002
+ cb(cur, "ffn_moe_gate_biased", il);
1003
+ }
1004
+
787
1005
  switch (type_op) {
788
1006
  case LLM_FFN_SILU:
789
- {
1007
+ if (gate_exps) {
1008
+ cur = ggml_swiglu_split(ctx0, cur, up);
1009
+ cb(cur, "ffn_moe_swiglu", il);
1010
+ } else {
790
1011
  cur = ggml_silu(ctx0, cur);
791
1012
  cb(cur, "ffn_moe_silu", il);
792
1013
  } break;
793
1014
  case LLM_FFN_GELU:
794
- {
1015
+ if (gate_exps) {
1016
+ cur = ggml_geglu_split(ctx0, cur, up);
1017
+ cb(cur, "ffn_moe_geglu", il);
1018
+ } else {
795
1019
  cur = ggml_gelu(ctx0, cur);
796
1020
  cb(cur, "ffn_moe_gelu", il);
797
1021
  } break;
1022
+ case LLM_FFN_SWIGLU_OAI_MOE:
1023
+ {
1024
+ // TODO: move to hparams?
1025
+ constexpr float alpha = 1.702f;
1026
+ constexpr float limit = 7.0f;
1027
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1028
+ cb(cur, "ffn_moe_swiglu_oai", il);
1029
+ } break;
1030
+ case LLM_FFN_RELU:
1031
+ if (gate_exps) {
1032
+ cur = ggml_reglu_split(ctx0, cur, up);
1033
+ cb(cur, "ffn_moe_reglu", il);
1034
+ } else {
1035
+ cur = ggml_relu(ctx0, cur);
1036
+ cb(cur, "ffn_moe_relu", il);
1037
+ } break;
798
1038
  default:
799
1039
  GGML_ABORT("fatal error");
800
1040
  }
801
1041
 
802
- if (gate_exps) {
803
- cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
804
- cb(cur, "ffn_moe_gate_par", il);
805
- }
806
-
807
1042
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
808
1043
  cb(experts, "ffn_moe_down", il);
809
1044
 
1045
+ if (down_exps_b) {
1046
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1047
+ cb(experts, "ffn_moe_down_biased", il);
1048
+ }
1049
+
810
1050
  if (!weight_before_ffn) {
811
1051
  experts = ggml_mul(ctx0, experts, weights);
812
1052
  cb(cur, "ffn_moe_weighted", il);
813
1053
  }
814
1054
 
1055
+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1056
+
1057
+ assert(n_expert_used > 0);
1058
+
1059
+ // order the views before the adds
1060
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1061
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1062
+
1063
+ ggml_build_forward_expand(gf, cur_experts[i]);
1064
+ }
1065
+
815
1066
  // aggregate experts
816
- ggml_tensor * moe_out = nullptr;
817
- for (int i = 0; i < n_expert_used; ++i) {
818
- ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
819
- experts->nb[2], i*experts->nb[1]);
1067
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1068
+ // to avoid potentially a large number of add nodes during warmup
1069
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
1070
+ ggml_tensor * moe_out = cur_experts[0];
820
1071
 
821
- if (i == 0) {
822
- moe_out = cur_expert;
823
- } else {
824
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
825
- }
1072
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1073
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
826
1074
  }
827
1075
 
828
- if (n_expert_used == 1) {
1076
+ if (hparams.n_expert_used == 1) {
829
1077
  // avoid returning a non-contiguous tensor
830
1078
  moe_out = ggml_cont(ctx0, moe_out);
831
1079
  }
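Putting the routing pieces together: probabilities (or raw logits) select the top-k experts, the gathered weights are optionally softmaxed or renormalized, and the expert outputs are summed with those weights. A per-token scalar sketch of that flow on plain vectors (moe_token is illustrative only):

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    static std::vector<float> moe_token(const std::vector<float> & probs,                 // [n_expert]
                                        const std::vector<std::vector<float>> & expert_out, // [n_expert][n_embd]
                                        int n_expert_used, bool norm_w) {
        std::vector<int> idx(probs.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                          [&](int a, int b) { return probs[a] > probs[b]; }); // ggml_top_k

        std::vector<float> w(n_expert_used);
        float sum = 0.0f;
        for (int i = 0; i < n_expert_used; ++i) { w[i] = probs[idx[i]]; sum += w[i]; }
        if (norm_w) { for (auto & x : w) x /= sum; } // selected weights sum to 1

        std::vector<float> out(expert_out[0].size(), 0.0f);
        for (int i = 0; i < n_expert_used; ++i) {
            for (size_t d = 0; d < out.size(); ++d) {
                out[d] += w[i]*expert_out[idx[i]][d]; // ffn_moe_weighted + aggregation
            }
        }
        return out;
    }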
@@ -888,11 +1136,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
888
1136
  }
889
1137
 
890
1138
  ggml_tensor * llm_graph_context::build_inp_pos() const {
891
- auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
1139
+ auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
892
1140
 
893
1141
  auto & cur = inp->pos;
894
1142
 
895
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
1143
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
896
1144
  ggml_set_input(cur);
897
1145
 
898
1146
  res->add_input(std::move(inp));
@@ -915,6 +1163,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
915
1163
  }
916
1164
 
917
1165
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1166
+ // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1167
+ // but this would make the graph topology depend on the number of output tokens, which can interfere with
1168
+ // features that require constant topology such as pipeline parallelism
1169
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1170
+ //if (n_outputs < n_tokens) {
1171
+ // return nullptr;
1172
+ //}
1173
+
918
1174
  auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
919
1175
 
920
1176
  auto & cur = inp->out_ids;
@@ -932,7 +1188,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
932
1188
 
933
1189
  auto & cur = inp->mean;
934
1190
 
935
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
1191
+ cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
936
1192
  ggml_set_input(cur);
937
1193
 
938
1194
  res->add_input(std::move(inp));
@@ -941,45 +1197,11 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
941
1197
  }
942
1198
 
943
1199
  ggml_tensor * llm_graph_context::build_inp_cls() const {
944
- auto inp = std::make_unique<llm_graph_input_cls>(cparams);
1200
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
945
1201
 
946
1202
  auto & cur = inp->cls;
947
1203
 
948
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
949
- ggml_set_input(cur);
950
-
951
- res->add_input(std::move(inp));
952
-
953
- return cur;
954
- }
955
-
956
- ggml_tensor * llm_graph_context::build_inp_s_copy() const {
957
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
958
-
959
- auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
960
-
961
- const auto n_kv = kv_self->n;
962
-
963
- auto & cur = inp->s_copy;
964
-
965
- cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
966
- ggml_set_input(cur);
967
-
968
- res->add_input(std::move(inp));
969
-
970
- return cur;
971
- }
972
-
973
- ggml_tensor * llm_graph_context::build_inp_s_mask() const {
974
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
975
-
976
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
977
-
978
- const auto n_kv = kv_self->n;
979
-
980
- auto & cur = inp->s_mask;
981
-
982
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
1204
+ cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
983
1205
  ggml_set_input(cur);
984
1206
 
985
1207
  res->add_input(std::move(inp));
@@ -1025,11 +1247,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
1025
1247
  }
1026
1248
 
1027
1249
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
1028
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1250
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1029
1251
 
1030
- auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
1252
+ auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
1031
1253
 
1032
- const auto n_kv = kv_self->get_n();
1254
+ const auto n_kv = mctx_cur->get_n_kv();
1033
1255
 
1034
1256
  auto & cur = inp->pos_bucket;
1035
1257
 
@@ -1057,23 +1279,27 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
1057
1279
  }
1058
1280
 
1059
1281
  ggml_tensor * llm_graph_context::build_attn_mha(
1060
- ggml_cgraph * gf,
1061
1282
  ggml_tensor * q,
1062
1283
  ggml_tensor * k,
1063
1284
  ggml_tensor * v,
1064
1285
  ggml_tensor * kq_b,
1065
1286
  ggml_tensor * kq_mask,
1287
+ ggml_tensor * sinks,
1066
1288
  ggml_tensor * v_mla,
1067
- float kq_scale) const {
1289
+ float kq_scale,
1290
+ int il) const {
1068
1291
  const bool v_trans = v->nb[1] > v->nb[2];
1069
1292
 
1293
+ // split the batch into streams if needed
1294
+ const auto n_stream = k->ne[3];
1295
+
1296
+ q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1297
+
1070
1298
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1071
1299
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1072
1300
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1073
1301
 
1074
- const auto n_tokens = q->ne[1];
1075
- const auto n_head = q->ne[2];
1076
- const auto n_kv = k->ne[1];
1302
+ const auto n_kv = k->ne[1];
1077
1303
 
1078
1304
  ggml_tensor * cur;
1079
1305
 
@@ -1096,8 +1322,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1096
1322
 
1097
1323
  cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1098
1324
  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1325
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1099
1326
 
1100
- ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1327
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
1328
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1101
1329
 
1102
1330
  if (v_mla) {
1103
1331
  #if 0
@@ -1110,14 +1338,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1110
1338
  // The permutations are noops and only change how the tensor data is interpreted.
1111
1339
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1112
1340
  cur = ggml_mul_mat(ctx0, v_mla, cur);
1341
+ cb(cur, "fattn_mla", il);
1113
1342
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1114
1343
  cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1115
1344
  #endif
1116
1345
  }
1117
1346
 
1118
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1347
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1119
1348
  } else {
1120
1349
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1350
+ cb(kq, "kq", il);
1121
1351
 
1122
1352
  // note: this op tends to require high floating point range
1123
1353
  // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1125,42 +1355,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1125
1355
 
1126
1356
  if (arch == LLM_ARCH_GROK) {
1127
1357
  // need to do the following:
1128
- // multiply by attn_output_multiplyer of 0.08838834764831845
1358
+ // multiply by attn_output_multiplier
1129
1359
  // and then :
1130
1360
  // kq = 30 * tanh(kq / 30)
1131
1361
  // before the softmax below
1132
1362
 
1133
- kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
1134
- kq = ggml_scale(ctx0, kq, 30);
1363
+ kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1364
+ cb(kq, "kq_tanh", il);
1365
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1366
+ cb(kq, "kq_scaled", il);
1135
1367
  }
1136
1368
 
1137
1369
  if (hparams.attn_soft_cap) {
1138
1370
  kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1371
+ cb(kq, "kq_scaled_1", il);
1139
1372
  kq = ggml_tanh (ctx0, kq);
1373
+ cb(kq, "kq_tanh", il);
1140
1374
  kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1375
+ cb(kq, "kq_scaled_2", il);
1141
1376
  }
1142
1377
 
1143
1378
  if (kq_b) {
1144
1379
  kq = ggml_add(ctx0, kq, kq_b);
1380
+ cb(kq, "kq_plus_kq_b", il);
1145
1381
  }
1146
1382
 
1147
1383
  kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1384
+ ggml_soft_max_add_sinks(kq, sinks);
1385
+ cb(kq, "kq_soft_max", il);
1148
1386
 
1149
1387
  if (!v_trans) {
1150
1388
  // note: avoid this branch
1151
1389
  v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1390
+ cb(v, "v_cont", il);
1152
1391
  }
1153
1392
 
1154
1393
  ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1394
+ cb(kqv, "kqv", il);
1155
1395
 
1156
1396
  // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1157
1397
  if (v_mla) {
1158
1398
  kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1399
+ cb(kqv, "kqv_mla", il);
1159
1400
  }
1160
1401
 
1161
1402
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1162
1403
 
1163
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1404
+ // recombine streams
1405
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1164
1406
 
1165
1407
  if (!cparams.offload_kqv) {
1166
1408
  // all nodes between the KV store and the attention output are run on the CPU
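For the non-flash branch above, the per-row computation is: scale the raw K·Q logits, optionally squash them with the tanh soft-cap, add the mask (so -INF entries vanish), and softmax. A self-contained sketch of one query row (attn_row is illustrative; the real ggml_soft_max_ext also handles a fully masked row):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    static std::vector<float> attn_row(const std::vector<float> & kq_raw,
                                       const std::vector<float> & mask,
                                       float scale, float cap /* 0 = soft-cap disabled */) {
        std::vector<float> p(kq_raw.size());
        float mx = -INFINITY, sum = 0.0f;
        for (size_t i = 0; i < p.size(); ++i) {
            float kq = scale*kq_raw[i];
            if (cap > 0.0f) kq = cap*std::tanh(kq/cap);   // hparams.attn_soft_cap path
            p[i] = kq + mask[i];                          // -INF entries drop out below
            mx = std::max(mx, p[i]);
        }
        for (auto & x : p) { x = std::exp(x - mx); sum += x; } // numerically stable softmax
        for (auto & x : p) { x /= sum; }
        return p;
    }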
@@ -1177,8 +1419,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1177
1419
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1178
1420
 
1179
1421
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1180
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1181
- //cb(inp_kq_mask, "KQ_mask", -1);
1422
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1182
1423
  ggml_set_input(inp->kq_mask);
1183
1424
 
1184
1425
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1188,13 +1429,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1188
1429
 
1189
1430
  ggml_tensor * llm_graph_context::build_attn(
1190
1431
  llm_graph_input_attn_no_cache * inp,
1191
- ggml_cgraph * gf,
1192
1432
  ggml_tensor * wo,
1193
1433
  ggml_tensor * wo_b,
1194
1434
  ggml_tensor * q_cur,
1195
1435
  ggml_tensor * k_cur,
1196
1436
  ggml_tensor * v_cur,
1197
1437
  ggml_tensor * kq_b,
1438
+ ggml_tensor * sinks,
1198
1439
  ggml_tensor * v_mla,
1199
1440
  float kq_scale,
1200
1441
  int il) const {
@@ -1208,11 +1449,16 @@ ggml_tensor * llm_graph_context::build_attn(
1208
1449
 
1209
1450
  const auto & kq_mask = inp->get_kq_mask();
1210
1451
 
1452
+ // [TAG_NO_CACHE_PAD]
1453
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1454
+ // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1455
+ //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1456
+
1211
1457
  ggml_tensor * q = q_cur;
1212
1458
  ggml_tensor * k = k_cur;
1213
1459
  ggml_tensor * v = v_cur;
1214
1460
 
1215
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1461
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1216
1462
  cb(cur, "kqv_out", il);
1217
1463
 
1218
1464
  if (wo) {
@@ -1230,35 +1476,51 @@ ggml_tensor * llm_graph_context::build_attn(
1230
1476
  return cur;
1231
1477
  }
1232
1478
 
1233
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1234
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1479
+ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1480
+ ggml_context * ctx0,
1481
+ const llama_ubatch & ubatch,
1482
+ const llama_hparams & hparams,
1483
+ const llama_cparams & cparams,
1484
+ const llama_kv_cache_context * mctx_cur) {
1235
1485
 
1236
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
1486
+ auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1237
1487
 
1238
1488
  {
1239
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1489
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1490
+
1491
+ const auto n_kv = mctx_cur->get_n_kv();
1492
+ const auto n_tokens = ubatch.n_tokens;
1493
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1240
1494
 
1241
- const auto n_kv = kv_self->get_n();
1495
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1496
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1242
1497
 
1243
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1244
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1498
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1245
1499
  ggml_set_input(inp->self_kq_mask);
1246
1500
 
1247
1501
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1248
1502
  }
1249
1503
 
1250
- return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1504
+ return inp;
1505
+ }
1506
+
1507
+ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1508
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1509
+
1510
+ auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1511
+
1512
+ return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1251
1513
  }
1252
1514
 
1253
1515
  ggml_tensor * llm_graph_context::build_attn(
1254
- llm_graph_input_attn_kv_unified * inp,
1255
- ggml_cgraph * gf,
1516
+ llm_graph_input_attn_kv * inp,
1256
1517
  ggml_tensor * wo,
1257
1518
  ggml_tensor * wo_b,
1258
1519
  ggml_tensor * q_cur,
1259
1520
  ggml_tensor * k_cur,
1260
1521
  ggml_tensor * v_cur,
1261
1522
  ggml_tensor * kq_b,
1523
+ ggml_tensor * sinks,
1262
1524
  ggml_tensor * v_mla,
1263
1525
  float kq_scale,
1264
1526
  int il) const {
@@ -1268,27 +1530,30 @@ ggml_tensor * llm_graph_context::build_attn(
1268
1530
  ggml_build_forward_expand(gf, k_cur);
1269
1531
  ggml_build_forward_expand(gf, v_cur);
1270
1532
 
1271
- const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
1533
+ const auto * mctx_cur = inp->mctx;
1272
1534
 
1273
1535
  // store to KV cache
1274
1536
  {
1275
- ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
1276
- ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
1537
+ const auto & k_idxs = inp->get_k_idxs();
1538
+ const auto & v_idxs = inp->get_v_idxs();
1539
+
1540
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1541
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1277
1542
  }
1278
1543
 
1279
1544
  const auto & kq_mask = inp->get_kq_mask();
1280
1545
 
1281
1546
  ggml_tensor * q = q_cur;
1282
- ggml_tensor * k = kv_self->get_k(ctx0, il);
1283
- ggml_tensor * v = kv_self->get_v(ctx0, il);
1547
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1548
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1284
1549
 
1285
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1550
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1286
1551
  cb(cur, "kqv_out", il);
1287
1552
 
1288
1553
  if (wo) {
1289
1554
  cur = build_lora_mm(wo, cur);
1290
- if (arch == LLM_ARCH_GLM4) {
1291
- // GLM4 seems to have numerical issues with half-precision accumulators
1555
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1556
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1292
1557
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1293
1558
  }
1294
1559
  }
@@ -1300,73 +1565,56 @@ ggml_tensor * llm_graph_context::build_attn(
1300
1565
  return cur;
1301
1566
  }
1302
1567
 
1303
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1304
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1305
-
1306
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
1307
-
1308
- {
1309
- const auto n_kv = kv_self->get_kv_base()->get_n();
1310
-
1311
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1312
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1313
- ggml_set_input(inp->self_kq_mask);
1314
-
1315
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1316
- }
1317
-
1318
- {
1319
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1320
-
1321
- const auto n_kv = kv_self->get_kv_swa()->get_n();
1322
-
1323
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1324
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1325
- ggml_set_input(inp->self_kq_mask_swa);
1326
-
1327
- inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1328
- }
1329
-
1330
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1331
- }
1332
-
1333
1568
  ggml_tensor * llm_graph_context::build_attn(
1334
- llm_graph_input_attn_kv_unified_iswa * inp,
1335
- ggml_cgraph * gf,
1569
+ llm_graph_input_attn_kv_iswa * inp,
1336
1570
  ggml_tensor * wo,
1337
1571
  ggml_tensor * wo_b,
1338
1572
  ggml_tensor * q_cur,
1339
1573
  ggml_tensor * k_cur,
1340
1574
  ggml_tensor * v_cur,
1341
1575
  ggml_tensor * kq_b,
1576
+ ggml_tensor * sinks,
1342
1577
  ggml_tensor * v_mla,
1343
1578
  float kq_scale,
1344
1579
  int il) const {
1345
1580
  // these nodes are added to the graph together so that they are not reordered
1346
1581
  // by doing so, the number of splits in the graph is reduced
1347
1582
  ggml_build_forward_expand(gf, q_cur);
1348
- ggml_build_forward_expand(gf, k_cur);
1349
- ggml_build_forward_expand(gf, v_cur);
1583
+
1584
+ if (k_cur) {
1585
+ ggml_build_forward_expand(gf, k_cur);
1586
+ }
1587
+
1588
+ if (v_cur) {
1589
+ ggml_build_forward_expand(gf, v_cur);
1590
+ }
1591
+
1592
+ const auto * mctx_iswa = inp->mctx;
1350
1593
 
1351
1594
  const bool is_swa = hparams.is_swa(il);
1352
1595
 
1353
- const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
1596
+ const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
1354
1597
 
1355
- const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
1598
+ // optionally store to KV cache
1599
+ if (k_cur) {
1600
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
1356
1601
 
1357
- // store to KV cache
1358
- {
1359
- ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
1360
- ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
1602
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1603
+ }
1604
+
1605
+ if (v_cur) {
1606
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
1607
+
1608
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1361
1609
  }
1362
1610
 
1363
1611
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1364
1612
 
1365
1613
  ggml_tensor * q = q_cur;
1366
- ggml_tensor * k = kv->get_k(ctx0, il);
1367
- ggml_tensor * v = kv->get_v(ctx0, il);
1614
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1615
+ ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1368
1616
 
1369
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1617
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1370
1618
  cb(cur, "kqv_out", il);
1371
1619
 
1372
1620
  if (wo) {
@@ -1389,7 +1637,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1389
1637
 
1390
1638
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1391
1639
 
1392
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1640
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1393
1641
  ggml_set_input(inp->cross_kq_mask);
1394
1642
 
1395
1643
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1399,13 +1647,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1399
1647
 
1400
1648
  ggml_tensor * llm_graph_context::build_attn(
1401
1649
  llm_graph_input_attn_cross * inp,
1402
- ggml_cgraph * gf,
1403
1650
  ggml_tensor * wo,
1404
1651
  ggml_tensor * wo_b,
1405
1652
  ggml_tensor * q_cur,
1406
1653
  ggml_tensor * k_cur,
1407
1654
  ggml_tensor * v_cur,
1408
1655
  ggml_tensor * kq_b,
1656
+ ggml_tensor * sinks,
1409
1657
  ggml_tensor * v_mla,
1410
1658
  float kq_scale,
1411
1659
  int il) const {
@@ -1421,7 +1669,7 @@ ggml_tensor * llm_graph_context::build_attn(
1421
1669
  ggml_tensor * k = k_cur;
1422
1670
  ggml_tensor * v = v_cur;
1423
1671
 
1424
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1672
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1425
1673
  cb(cur, "kqv_out", il);
1426
1674
 
1427
1675
  if (wo) {
@@ -1439,56 +1687,135 @@ ggml_tensor * llm_graph_context::build_attn(
1439
1687
  return cur;
1440
1688
  }
1441
1689
 
1442
- ggml_tensor * llm_graph_context::build_copy_mask_state(
1443
- ggml_cgraph * gf,
1444
- ggml_tensor * s,
1445
- ggml_tensor * state_copy,
1446
- ggml_tensor * state_mask,
1447
- int32_t n_state,
1448
- int32_t n_seqs) const {
1449
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
1690
+ // TODO: maybe separate the inner implementation into a separate function
1691
+ // like with the non-sliding window equivalent
1692
+ // once sliding-window hybrid caches are a thing.
1693
+ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
1694
+ const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
1450
1695
 
1451
- const auto n_kv = kv_self->n;
1452
- const auto kv_head = kv_self->head;
1696
+ auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1453
1697
 
1454
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
1698
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1455
1699
 
1456
- // copy states
1457
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1458
- // this shrinks the tensors's ne[1] to n_kv
1459
- states = ggml_get_rows(ctx0, states, state_copy);
1700
+ {
1701
+ const auto n_kv = mctx_cur->get_base()->get_n_kv();
1460
1702
 
1461
- // clear states of sequences which are starting at the beginning of this batch
1462
- // FIXME: zero-out NANs?
1463
- states = ggml_mul(ctx0, states, state_mask);
1703
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1704
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1705
+
1706
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1707
+ ggml_set_input(inp->self_kq_mask);
1464
1708
 
1465
- // copy states which won't be changed further (between n_seqs and n_kv)
1709
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1710
+ }
1711
+
1712
+ {
1713
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1714
+
1715
+ const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1716
+
1717
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1718
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1719
+
1720
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1721
+ ggml_set_input(inp->self_kq_mask_swa);
1722
+
1723
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1724
+ }
1725
+
1726
+ return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
1727
+ }
+
+ ggml_tensor * llm_graph_context::build_rs(
+ ggml_tensor * s,
+ ggml_tensor * state_copy_main,
+ ggml_tensor * state_copy_extra,
+ int32_t state_size,
+ int32_t n_seqs,
+ uint32_t n_rs,
+ uint32_t rs_head,
+ uint32_t rs_size,
+ int32_t rs_zero,
+ const llm_graph_get_rows_fn & get_state_rows) const {
+
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
+
+ // Clear a single state which will then be copied to the other cleared states.
+ // Note that this is a no-op when the view is zero-sized.
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+ // {state_size, rs_size} -> {state_size, n_seqs}
+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
+ ggml_build_forward_expand(gf, output_states);
+
+ // copy extra states which won't be changed further (between n_seqs and n_rs)
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
  ggml_build_forward_expand(gf,
  ggml_cpy(ctx0,
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+ states_extra,
+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
+
+ return output_states;
+ }
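One detail in the new build_rs worth calling out: the clear of the rs_zero row is made unconditional through arithmetic rather than branching. The (rs_zero >= 0) factor collapses both the view length and the byte offset to zero when there is nothing to clear, so the in-place scale degenerates to a no-op. A plain C++ sketch of just that index arithmetic (toy values, no ggml):

#include <cstdint>
#include <cstdio>

// rs_zero < 0 means "no state to clear"; the boolean factor zeroes both the
// element count and the byte offset, so the resulting view is empty.
static void describe_clear_view(int32_t rs_zero, int64_t state_size, int64_t row_bytes) {
    const int64_t n_elems = state_size * (rs_zero >= 0);
    const int64_t offset  = (int64_t) rs_zero * row_bytes * (rs_zero >= 0);
    std::printf("rs_zero=%d -> view of %lld elems at byte offset %lld\n",
                rs_zero, (long long) n_elems, (long long) offset);
}

int main() {
    describe_clear_view(-1, 1024, 4096); // nothing to clear: zero-sized, zero-offset view
    describe_clear_view( 3, 1024, 4096); // clear row 3 before the copies
    return 0;
}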
+
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+ ggml_context * ctx0,
+ const llama_ubatch & ubatch,
+ const llama_memory_recurrent_context * mctx_cur) {
+
+ auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
+
+ const int64_t n_rs = mctx_cur->get_n_rs();
+ const int64_t n_seqs = ubatch.n_seqs;
+
+ inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+ ggml_set_input(inp->s_copy);
+
+ inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+ inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
- // the part of the states that will be used and modified
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+ return inp;
+ }
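The helper above allocates one I32 copy-index tensor and carves it into two contiguous views: the first n_seqs entries select the states the ubatch will actually use and modify, and the remaining n_rs - n_seqs entries only shuffle untouched states back into place. A toy sketch of the same split (plain C++, example permutation assumed):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_rs = 6, n_seqs = 2;
    // example row permutation that would be loaded into s_copy at eval time
    const std::vector<int32_t> s_copy = {4, 1, 0, 2, 3, 5};

    // s_copy_main: first n_seqs entries (states used and modified this ubatch)
    for (int i = 0; i < n_seqs; ++i)    std::printf("main[%d]  -> row %d\n", i, s_copy[i]);
    // s_copy_extra: trailing n_rs - n_seqs entries (kept, but moved back in place)
    for (int i = n_seqs; i < n_rs; ++i) std::printf("extra[%d] -> row %d\n", i - n_seqs, s_copy[i]);
    return 0;
}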
+
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
+
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
+ }
+
+ ggml_tensor * llm_graph_context::build_rs(
+ llm_graph_input_rs * inp,
+ ggml_tensor * s,
+ int32_t state_size,
+ int32_t n_seqs,
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = inp->mctx;
+
+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+ get_state_rows);
  }
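The public build_rs overload forwards to the low-level one, pulling the cache geometry (n_rs, head, size, rs_z) from the memory context so callers only supply the input, the state tensor, the state size, and the sequence count. The get_state_rows callback is conceptually just a row gather; a toy stand-in over a row-major matrix (illustrative types, not the real ggml signature):

#include <cstdint>
#include <functional>
#include <vector>

using matrix      = std::vector<std::vector<float>>; // toy stand-in for a 2-D tensor
using get_rows_fn = std::function<matrix(const matrix &, const std::vector<int32_t> &)>;

int main() {
    // default behavior: plain row gather, in the spirit of ggml_get_rows
    const get_rows_fn get_rows = [](const matrix & m, const std::vector<int32_t> & ids) {
        matrix out;
        out.reserve(ids.size());
        for (int32_t id : ids) out.push_back(m[id]);
        return out;
    };

    const matrix states = {{0.f}, {1.f}, {2.f}, {3.f}};
    const matrix picked = get_rows(states, {2, 0}); // select rows 2 and 0, in that order
    (void) picked;
    return 0;
}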
 
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
- ggml_cgraph * gf,
- ggml_tensor * state_copy,
- ggml_tensor * state_mask,
- const llama_ubatch & ubatch,
- int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ llm_graph_input_rs * inp,
+ const llama_ubatch & ubatch,
+ int il) const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
  const auto token_shift_count = hparams.token_shift_count;
 
  const int64_t n_seqs = ubatch.n_seqs;
 
- ggml_tensor * token_shift_all = kv_self->k_l[il];
+ ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
- ggml_tensor * token_shift = build_copy_mask_state(
- gf, token_shift_all, state_copy, state_mask,
- hparams.n_embd_k_s(), n_seqs);
+ ggml_tensor * token_shift = build_rs(
+ inp, token_shift_all,
+ hparams.n_embd_r(), n_seqs);
 
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1499,24 +1826,34 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
  ggml_tensor * token_shift,
  const llama_ubatch & ubatch,
  int il) const {
- const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
  const auto token_shift_count = hparams.token_shift_count;
  const auto n_embd = hparams.n_embd;
 
  const int64_t n_seqs = ubatch.n_seqs;
 
- const auto kv_head = kv_self->head;
+ const auto kv_head = mctx_cur->get_head();
 
  return ggml_cpy(
  ctx0,
  ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
- ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+ ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
  );
  }
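Load and store are symmetric: build_rwkv_token_shift_load pulls hparams.n_embd_r() elements per sequence out of the recurrent cache, and the store above writes the same amount back at row kv_head. A small consistency sketch, under the assumption (not stated in the diff) that n_embd_r() equals n_embd * token_shift_count for these models:

#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_embd            = 2048; // example
    const int64_t token_shift_count = 2;    // example: two shift states per layer
    const int64_t n_seqs            = 3;

    const int64_t n_embd_r    = n_embd * token_shift_count;          // assumed definition
    const int64_t load_elems  = n_embd * n_seqs * token_shift_count; // view size in the store's src
    const int64_t store_elems = n_embd_r * n_seqs;                   // view size in the cache dst

    assert(load_elems == store_elems); // the round trip is size-consistent
    return 0;
}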
 
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+ auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+ }
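build_inp_mem_hybrid composes the two specialized inputs, an attention-KV input for the attention layers and a recurrent-state input for the SSM/RWKV layers, and registers them as a single graph input. A simplified structural sketch (toy types, not the real llm_graph_input_mem_hybrid layout):

#include <memory>
#include <utility>

struct attn_input { /* k/v index tensors, kq mask, ... */ };
struct rs_input   { /* s_copy and its main/extra views */ };

// one graph input that owns both halves; real member layout may differ
struct mem_hybrid_input {
    std::unique_ptr<attn_input> attn;
    std::unique_ptr<rs_input>   rs;

    mem_hybrid_input(std::unique_ptr<attn_input> a, std::unique_ptr<rs_input> r)
        : attn(std::move(a)), rs(std::move(r)) {}
};

int main() {
    mem_hybrid_input inp(std::make_unique<attn_input>(), std::make_unique<rs_input>());
    (void) inp;
    return 0;
}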
1855
+
1518
1856
  void llm_graph_context::build_pooling(
1519
- ggml_cgraph * gf,
1520
1857
  ggml_tensor * cls,
1521
1858
  ggml_tensor * cls_b,
1522
1859
  ggml_tensor * cls_out,
@@ -1560,22 +1897,32 @@ void llm_graph_context::build_pooling(
1560
1897
  case LLAMA_POOLING_TYPE_RANK:
1561
1898
  {
1562
1899
  ggml_tensor * inp_cls = build_inp_cls();
1563
- inp = ggml_get_rows(ctx0, inp, inp_cls);
1900
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
1564
1901
 
1565
1902
  // classification head
1566
1903
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1567
- GGML_ASSERT(cls != nullptr);
1568
- GGML_ASSERT(cls_b != nullptr);
1569
-
1570
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1571
- cur = ggml_tanh(ctx0, cur);
1904
+ if (cls) {
1905
+ cur = ggml_mul_mat(ctx0, cls, cur);
1906
+ if (cls_b) {
1907
+ cur = ggml_add(ctx0, cur, cls_b);
1908
+ }
1909
+ cur = ggml_tanh(ctx0, cur);
1910
+ }
1572
1911
 
1573
1912
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1574
1913
  // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1914
+ // Single layer classification head (direct projection)
1915
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1575
1916
  if (cls_out) {
1576
- GGML_ASSERT(cls_out_b != nullptr);
1917
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1918
+ if (cls_out_b) {
1919
+ cur = ggml_add(ctx0, cur, cls_out_b);
1920
+ }
1921
+ }
1577
1922
 
1578
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1923
+ // softmax for qwen3 reranker
1924
+ if (arch == LLM_ARCH_QWEN3) {
1925
+ cur = ggml_soft_max(ctx0, cur);
1579
1926
  }
1580
1927
  } break;
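The RANK pooling rewrite replaces the hard GGML_ASSERTs with optional projections: cls, cls_b, cls_out, and cls_out_b may each be absent from the model file, and a softmax is appended only for the Qwen3 reranker. A scalar toy model of the same control flow (plain C++, values illustrative):

#include <cmath>
#include <cstdio>
#include <optional>

// Toy scalar stand-in for the tensor pipeline: each optional mirrors a
// tensor that may be missing from the model file.
static float rank_head(float x,
                       std::optional<float> cls, std::optional<float> cls_b,
                       std::optional<float> cls_out, std::optional<float> cls_out_b,
                       bool is_qwen3) {
    if (cls) {                        // pooled projection, like ggml_mul_mat(cls, cur)
        x = *cls * x;
        if (cls_b) x += *cls_b;       // bias is now optional too
        x = std::tanh(x);
    }
    if (cls_out) {                    // optional output projection
        x = *cls_out * x;
        if (cls_out_b) x += *cls_out_b;
    }
    if (is_qwen3) {                   // softmax over a single score is identically 1;
        x = 1.0f;                     // the real graph normalizes across the logits row
    }
    return x;
}

int main() {
    std::printf("%f\n", rank_head(0.3f, 0.5f, std::nullopt, 2.0f, 0.1f, false));
    return 0;
}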
  default: