whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -4,8 +4,8 @@
4
4
  #include "llama-batch.h"
5
5
  #include "llama-cparams.h"
6
6
 
7
- #include "llama-kv-cache-unified.h"
8
- #include "llama-kv-cache-unified-iswa.h"
7
+ #include "llama-kv-cache.h"
8
+ #include "llama-kv-cache-iswa.h"
9
9
  #include "llama-memory-hybrid.h"
10
10
  #include "llama-memory-recurrent.h"
11
11
 
@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
28
28
  }
29
29
  }
30
30
 
31
+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
32
+ bool res = true;
33
+
34
+ res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
35
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
36
+
37
+ return res;
38
+ }
39
+
31
40
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
32
41
  if (ubatch->pos && pos) {
33
42
  const int64_t n_tokens = ubatch->n_tokens;
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
50
59
  }
51
60
  }
52
61
 
62
+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
63
+ bool res = true;
64
+
65
+ res &= pos->ne[0] == params.ubatch.n_tokens;
66
+
67
+ return res;
68
+ }
69
+
53
70
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
54
71
  if (ubatch->pos && attn_scale) {
55
72
  const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
71
88
  const int64_t n_tokens = ubatch->n_tokens;
72
89
 
73
90
  GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
74
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
91
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
75
92
 
76
93
  int32_t * data = (int32_t *) pos_bucket->data;
77
94
 
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
118
135
  }
119
136
  }
120
137
 
138
+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
139
+ bool res = true;
140
+
141
+ res &= n_outputs == params.n_outputs;
142
+
143
+ return res;
144
+ }
145
+
121
146
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
122
147
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
123
148
  const int64_t n_tokens = ubatch->n_tokens;
@@ -163,38 +188,26 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
163
188
 
164
189
  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
165
190
  const int64_t n_tokens = ubatch->n_tokens;
166
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
167
191
  const int64_t n_seqs_unq = ubatch->n_seqs_unq;
168
192
 
169
193
  if (cparams.embeddings && (
170
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
171
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
172
- )) {
194
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197
+ )) {
173
198
  GGML_ASSERT(cls);
174
199
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
175
200
 
176
201
  uint32_t * data = (uint32_t *) cls->data;
177
202
  memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
178
203
 
179
- for (int i = 0; i < n_tokens; i += n_seq_tokens) {
180
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
181
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
182
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
183
-
184
- data[seq_idx] = i;
185
- }
186
- }
187
- }
204
+ std::vector<int> target_pos(n_seqs_unq, -1);
205
+ std::vector<int> target_row(n_seqs_unq, -1);
188
206
 
189
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
190
- GGML_ASSERT(cls);
191
- GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
192
-
193
- uint32_t * data = (uint32_t *) cls->data;
194
- memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
195
-
196
- std::vector<int> last_pos(n_seqs_unq, -1);
197
- std::vector<int> last_row(n_seqs_unq, -1);
207
+ const bool last = (
208
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
209
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
210
+ );
198
211
 
199
212
  for (int i = 0; i < n_tokens; ++i) {
200
213
  const llama_pos pos = ubatch->pos[i];
@@ -203,16 +216,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
203
216
  const llama_seq_id seq_id = ubatch->seq_id[i][s];
204
217
  const int32_t seq_idx = ubatch->seq_idx[seq_id];
205
218
 
206
- if (pos >= last_pos[seq_idx]) {
207
- last_pos[seq_idx] = pos;
208
- last_row[seq_idx] = i;
219
+ if (
220
+ (target_pos[seq_idx] == -1) ||
221
+ ( last && pos >= target_pos[seq_idx]) ||
222
+ (!last && pos < target_pos[seq_idx])
223
+ ) {
224
+ target_pos[seq_idx] = pos;
225
+ target_row[seq_idx] = i;
209
226
  }
210
227
  }
211
228
  }
212
229
 
213
230
  for (int s = 0; s < n_seqs_unq; ++s) {
214
- if (last_row[s] >= 0) {
215
- data[s] = last_row[s];
231
+ if (target_row[s] >= 0) {
232
+ data[s] = target_row[s];
216
233
  }
217
234
  }
218
235
  }
@@ -244,6 +261,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
244
261
  }
245
262
  }
246
263
 
264
+ static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
265
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
266
+ const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
267
+ (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
268
+ (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
269
+ (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
270
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
271
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
272
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
273
+
274
+ LLAMA_LOG_DEBUG(" ");
275
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
276
+ LLAMA_LOG_DEBUG("%2d", j);
277
+ }
278
+ LLAMA_LOG_DEBUG("\n");
279
+
280
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
281
+ LLAMA_LOG_DEBUG(" %2d ", i);
282
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
283
+ float val = data[i * n_kv + j];
284
+ if (val == -INFINITY) {
285
+ LLAMA_LOG_DEBUG(" ∞");
286
+ } else {
287
+ LLAMA_LOG_DEBUG(" 0");
288
+ }
289
+ }
290
+ LLAMA_LOG_DEBUG("\n");
291
+ }
292
+ }
293
+
247
294
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
248
295
  const int64_t n_kv = ubatch->n_tokens;
249
296
  const int64_t n_tokens = ubatch->n_tokens;
@@ -253,6 +300,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
253
300
 
254
301
  float * data = (float *) kq_mask->data;
255
302
 
303
+ // [TAG_NO_CACHE_ISWA]
304
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
305
+
256
306
  for (int h = 0; h < 1; ++h) {
257
307
  for (int i1 = 0; i1 < n_tokens; ++i1) {
258
308
  const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -263,37 +313,90 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
263
313
  for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
264
314
  const llama_seq_id s0 = ubatch->seq_id[i0][0];
265
315
 
316
+ if (s0 != s1) {
317
+ continue; // skip different sequences
318
+ }
319
+
320
+ if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
321
+ continue; // skip future tokens for causal attention
322
+ }
323
+
324
+ // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
325
+ //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
326
+ // continue; // skip masked tokens for SWA
327
+ //}
328
+
266
329
  // TODO: reimplement this like in llama_kv_cache_unified
267
- if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
268
- if (hparams.use_alibi) {
269
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
270
- } else {
271
- f = 0.0f;
272
- }
273
- break;
330
+ if (hparams.use_alibi) {
331
+ f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
332
+ } else {
333
+ f = 0.0f;
274
334
  }
275
335
  }
276
-
277
336
  data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
278
337
  }
279
338
  }
280
339
  }
340
+ if (debug) {
341
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
342
+ }
281
343
  }
282
344
 
283
- void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
284
- if (self_kq_mask) {
285
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
286
- }
345
+ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
346
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
347
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
348
+
349
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
287
350
  }
288
351
 
289
- void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
290
- if (self_kq_mask) {
291
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
292
- }
352
+ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
353
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
293
354
 
294
- if (self_kq_mask_swa) {
295
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
296
- }
355
+ this->mctx = mctx;
356
+
357
+ bool res = true;
358
+
359
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
360
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
361
+
362
+ res &= self_kq_mask->ne[0] == mctx->get_n_kv();
363
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
364
+
365
+ return res;
366
+ }
367
+
368
+ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
369
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
370
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
371
+
372
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
373
+
374
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
375
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
376
+
377
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
378
+ }
379
+
380
+ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
381
+ const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
382
+
383
+ this->mctx = mctx;
384
+
385
+ bool res = true;
386
+
387
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
388
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
389
+
390
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
391
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
392
+
393
+ res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
394
+ res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
395
+
396
+ res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
397
+ res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
398
+
399
+ return res;
297
400
  }
298
401
 
299
402
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -303,7 +406,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
303
406
  const int64_t n_tokens = ubatch->n_tokens;
304
407
 
305
408
  GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
306
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
409
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
307
410
 
308
411
  float * data = (float *) cross_kq_mask->data;
309
412
 
@@ -333,27 +436,93 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
333
436
  }
334
437
 
335
438
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
336
- if (self_kq_mask) {
337
- mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
439
+ inp_attn->set_input(ubatch);
440
+ inp_rs->set_input(ubatch);
441
+ }
442
+
443
+ //
444
+ // llm_graph_result
445
+ //
446
+
447
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
448
+ reset();
449
+
450
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
451
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
452
+ }
453
+
454
+ int64_t llm_graph_result::get_max_nodes() const {
455
+ return max_nodes;
456
+ }
457
+
458
+ void llm_graph_result::reset() {
459
+ t_tokens = nullptr;
460
+ t_logits = nullptr;
461
+ t_embd = nullptr;
462
+ t_embd_pooled = nullptr;
463
+
464
+ params = {};
465
+
466
+ inputs.clear();
467
+
468
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
469
+
470
+ ggml_init_params params = {
471
+ /*.mem_size =*/ buf_compute_meta.size(),
472
+ /*.mem_buffer =*/ buf_compute_meta.data(),
473
+ /*.no_alloc =*/ true,
474
+ };
475
+
476
+ ctx_compute.reset(ggml_init(params));
477
+
478
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
479
+ }
480
+
481
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
482
+ for (auto & input : inputs) {
483
+ input->set_input(ubatch);
338
484
  }
485
+ }
339
486
 
340
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
487
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
488
+ if (!this->params.allow_reuse(params)) {
489
+ if (debug > 1) {
490
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
491
+ }
341
492
 
342
- if (s_copy) {
343
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
344
- int32_t * data = (int32_t *) s_copy->data;
493
+ return false;
494
+ }
345
495
 
346
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
347
- for (uint32_t i = 0; i < n_rs; ++i) {
348
- data[i] = mctx->get_recr()->s_copy(i);
496
+ if (debug > 1) {
497
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
498
+ }
499
+
500
+ bool res = true;
501
+
502
+ for (auto & input : inputs) {
503
+ const bool cur = input->can_reuse(params);
504
+
505
+ if (debug > 1) {
506
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
349
507
  }
508
+
509
+ res = res && cur;
510
+ }
511
+
512
+ if (debug > 0) {
513
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
350
514
  }
515
+
516
+ return res;
517
+ }
518
+
519
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
520
+ inputs.emplace_back(std::move(input));
521
+ return inputs.back().get();
351
522
  }
352
523
 
353
- void llm_graph_input_one::set_input(const llama_ubatch *) {
354
- GGML_ASSERT(one && ggml_nelements(one) == 1);
355
- float f_one = 1.0f;
356
- ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
524
+ void llm_graph_result::set_params(const llm_graph_params & params) {
525
+ this->params = params;
357
526
  }
358
527
 
359
528
  //
@@ -390,7 +559,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
390
559
  n_ctx_orig (cparams.n_ctx_orig_yarn),
391
560
  pooling_type (cparams.pooling_type),
392
561
  rope_type (hparams.rope_type),
393
- ctx0 (params.ctx),
394
562
  sched (params.sched),
395
563
  backend_cpu (params.backend_cpu),
396
564
  cvec (params.cvec),
@@ -398,7 +566,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
398
566
  mctx (params.mctx),
399
567
  cross (params.cross),
400
568
  cb_func (params.cb),
401
- res (std::make_unique<llm_graph_result>()) {
569
+ res (params.res),
570
+ ctx0 (res->get_ctx()),
571
+ gf (res->get_gf()) {
572
+ res->set_params(params);
402
573
  }
403
574
 
404
575
  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -613,6 +784,8 @@ ggml_tensor * llm_graph_context::build_ffn(
613
784
  cur = ggml_reglu(ctx0, cur);
614
785
  cb(cur, "ffn_reglu", il);
615
786
  } break;
787
+ default:
788
+ GGML_ABORT("fatal error");
616
789
  }
617
790
 
618
791
  if (gate && type_gate == LLM_FFN_PAR) {
@@ -622,8 +795,8 @@ ggml_tensor * llm_graph_context::build_ffn(
622
795
 
623
796
  if (down) {
624
797
  cur = build_lora_mm(down, cur);
625
- if (arch == LLM_ARCH_GLM4) {
626
- // GLM4 seems to have numerical issues with half-precision accumulators
798
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
799
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
627
800
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
628
801
  }
629
802
  }
@@ -658,13 +831,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
658
831
  bool scale_w,
659
832
  float w_scale,
660
833
  llama_expert_gating_func_type gating_op,
661
- int il) const {
834
+ int il,
835
+ ggml_tensor * probs_in) const {
836
+ return build_moe_ffn(
837
+ cur,
838
+ gate_inp, /* gate_inp_b */ nullptr,
839
+ up_exps, /* up_exps_b */ nullptr,
840
+ gate_exps, /* gate_exps_b */ nullptr,
841
+ down_exps, /* down_exps_b */ nullptr,
842
+ exp_probs_b,
843
+ n_expert,
844
+ n_expert_used,
845
+ type_op,
846
+ norm_w,
847
+ scale_w,
848
+ w_scale,
849
+ gating_op,
850
+ il,
851
+ probs_in
852
+ );
853
+ }
854
+
855
+ ggml_tensor * llm_graph_context::build_moe_ffn(
856
+ ggml_tensor * cur,
857
+ ggml_tensor * gate_inp,
858
+ ggml_tensor * gate_inp_b,
859
+ ggml_tensor * up_exps,
860
+ ggml_tensor * up_exps_b,
861
+ ggml_tensor * gate_exps,
862
+ ggml_tensor * gate_exps_b,
863
+ ggml_tensor * down_exps,
864
+ ggml_tensor * down_exps_b,
865
+ ggml_tensor * exp_probs_b,
866
+ int64_t n_expert,
867
+ int64_t n_expert_used,
868
+ llm_ffn_op_type type_op,
869
+ bool norm_w,
870
+ bool scale_w,
871
+ float w_scale,
872
+ llama_expert_gating_func_type gating_op,
873
+ int il,
874
+ ggml_tensor * probs_in) const {
662
875
  const int64_t n_embd = cur->ne[0];
663
876
  const int64_t n_tokens = cur->ne[1];
664
877
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
665
878
 
666
- ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
667
- cb(logits, "ffn_moe_logits", il);
879
+ ggml_tensor * logits = nullptr;
880
+
881
+ if (probs_in == nullptr) {
882
+ logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
883
+ cb(logits, "ffn_moe_logits", il);
884
+ } else {
885
+ logits = probs_in;
886
+ }
887
+
888
+ if (gate_inp_b) {
889
+ logits = ggml_add(ctx0, logits, gate_inp_b);
890
+ cb(logits, "ffn_moe_logits_biased", il);
891
+ }
668
892
 
669
893
  ggml_tensor * probs = nullptr;
670
894
  switch (gating_op) {
@@ -676,6 +900,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
676
900
  {
677
901
  probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
678
902
  } break;
903
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
904
+ {
905
+ probs = logits; // [n_expert, n_tokens]
906
+ } break;
679
907
  default:
680
908
  GGML_ABORT("fatal error");
681
909
  }
@@ -695,15 +923,36 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
695
923
  selection_probs = logits;
696
924
  }
697
925
 
926
+ if (arch == LLM_ARCH_GROVEMOE) {
927
+ selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
928
+ cb(selection_probs, "ffn_moe_probs_biased", il);
929
+ }
930
+
698
931
  // select experts
699
932
  ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
700
933
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
701
934
  cb(selected_experts, "ffn_moe_topk", il);
702
935
 
703
- ggml_tensor * weights = ggml_get_rows(ctx0,
704
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
936
+ if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
937
+ // TODO: Use scalar div instead when/if implemented
938
+ ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
939
+ selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
940
+ probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
941
+ } else {
942
+ probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
943
+ }
944
+
945
+ ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
705
946
  cb(weights, "ffn_moe_weights", il);
706
947
 
948
+
949
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
950
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
951
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
952
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
953
+ cb(weights, "ffn_moe_weights_softmax", il);
954
+ }
955
+
707
956
  if (norm_w) {
708
957
  weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
709
958
 
@@ -720,6 +969,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
720
969
  cb(weights, "ffn_moe_weights_scaled", il);
721
970
  }
722
971
 
972
+ //call early so that topk-moe can be used
973
+ ggml_build_forward_expand(gf, weights);
974
+
723
975
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
724
976
 
725
977
  if (weight_before_ffn) {
@@ -732,6 +984,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
732
984
  ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
733
985
  cb(up, "ffn_moe_up", il);
734
986
 
987
+ if (up_exps_b) {
988
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
989
+ cb(up, "ffn_moe_up_biased", il);
990
+ }
991
+
735
992
  ggml_tensor * experts = nullptr;
736
993
  if (gate_exps) {
737
994
  cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -740,6 +997,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
740
997
  cur = up;
741
998
  }
742
999
 
1000
+ if (gate_exps_b) {
1001
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1002
+ cb(cur, "ffn_moe_gate_biased", il);
1003
+ }
1004
+
743
1005
  switch (type_op) {
744
1006
  case LLM_FFN_SILU:
745
1007
  if (gate_exps) {
@@ -757,6 +1019,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
757
1019
  cur = ggml_gelu(ctx0, cur);
758
1020
  cb(cur, "ffn_moe_gelu", il);
759
1021
  } break;
1022
+ case LLM_FFN_SWIGLU_OAI_MOE:
1023
+ {
1024
+ // TODO: move to hparams?
1025
+ constexpr float alpha = 1.702f;
1026
+ constexpr float limit = 7.0f;
1027
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
1028
+ cb(cur, "ffn_moe_swiglu_oai", il);
1029
+ } break;
1030
+ case LLM_FFN_RELU:
1031
+ if (gate_exps) {
1032
+ cur = ggml_reglu_split(ctx0, cur, up);
1033
+ cb(cur, "ffn_moe_reglu", il);
1034
+ } else {
1035
+ cur = ggml_relu(ctx0, cur);
1036
+ cb(cur, "ffn_moe_relu", il);
1037
+ } break;
760
1038
  default:
761
1039
  GGML_ABORT("fatal error");
762
1040
  }
@@ -764,25 +1042,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
764
1042
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
765
1043
  cb(experts, "ffn_moe_down", il);
766
1044
 
1045
+ if (down_exps_b) {
1046
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
1047
+ cb(experts, "ffn_moe_down_biased", il);
1048
+ }
1049
+
767
1050
  if (!weight_before_ffn) {
768
1051
  experts = ggml_mul(ctx0, experts, weights);
769
1052
  cb(cur, "ffn_moe_weighted", il);
770
1053
  }
771
1054
 
1055
+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
1056
+
1057
+ assert(n_expert_used > 0);
1058
+
1059
+ // order the views before the adds
1060
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
1061
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
1062
+
1063
+ ggml_build_forward_expand(gf, cur_experts[i]);
1064
+ }
1065
+
772
1066
  // aggregate experts
773
- ggml_tensor * moe_out = nullptr;
774
- for (int i = 0; i < n_expert_used; ++i) {
775
- ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
776
- experts->nb[2], i*experts->nb[1]);
1067
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
1068
+ // to avoid potentially a large number of add nodes during warmup
1069
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
1070
+ ggml_tensor * moe_out = cur_experts[0];
777
1071
 
778
- if (i == 0) {
779
- moe_out = cur_expert;
780
- } else {
781
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
782
- }
1072
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
1073
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
783
1074
  }
784
1075
 
785
- if (n_expert_used == 1) {
1076
+ if (hparams.n_expert_used == 1) {
786
1077
  // avoid returning a non-contiguous tensor
787
1078
  moe_out = ggml_cont(ctx0, moe_out);
788
1079
  }
@@ -906,7 +1197,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
906
1197
  }
907
1198
 
908
1199
  ggml_tensor * llm_graph_context::build_inp_cls() const {
909
- auto inp = std::make_unique<llm_graph_input_cls>(cparams);
1200
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
910
1201
 
911
1202
  auto & cur = inp->cls;
912
1203
 
@@ -956,7 +1247,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
956
1247
  }
957
1248
 
958
1249
  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
959
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1250
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
960
1251
 
961
1252
  auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
962
1253
 
@@ -987,51 +1278,28 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
987
1278
  return pos_bias;
988
1279
  }
989
1280
 
990
- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
991
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
992
-
993
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
994
-
995
- {
996
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
997
-
998
- const auto n_kv = inp->mctx->get_attn()->get_n_kv();
999
-
1000
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1001
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1002
- ggml_set_input(inp->self_kq_mask);
1003
-
1004
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1005
- }
1006
-
1007
- {
1008
- const auto n_rs = mctx_cur->get_recr()->get_n_rs();
1009
-
1010
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1011
- ggml_set_input(inp->s_copy);
1012
- }
1013
-
1014
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1015
- }
1016
-
1017
1281
  ggml_tensor * llm_graph_context::build_attn_mha(
1018
- ggml_cgraph * gf,
1019
1282
  ggml_tensor * q,
1020
1283
  ggml_tensor * k,
1021
1284
  ggml_tensor * v,
1022
1285
  ggml_tensor * kq_b,
1023
1286
  ggml_tensor * kq_mask,
1287
+ ggml_tensor * sinks,
1024
1288
  ggml_tensor * v_mla,
1025
- float kq_scale) const {
1289
+ float kq_scale,
1290
+ int il) const {
1026
1291
  const bool v_trans = v->nb[1] > v->nb[2];
1027
1292
 
1293
+ // split the batch into streams if needed
1294
+ const auto n_stream = k->ne[3];
1295
+
1296
+ q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
1297
+
1028
1298
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
1029
1299
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1030
1300
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1031
1301
 
1032
- const auto n_tokens = q->ne[1];
1033
- const auto n_head = q->ne[2];
1034
- const auto n_kv = k->ne[1];
1302
+ const auto n_kv = k->ne[1];
1035
1303
 
1036
1304
  ggml_tensor * cur;
1037
1305
 
@@ -1054,8 +1322,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1054
1322
 
1055
1323
  cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
1056
1324
  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
1325
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
1057
1326
 
1058
- ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
1327
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
1328
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
1059
1329
 
1060
1330
  if (v_mla) {
1061
1331
  #if 0
@@ -1068,14 +1338,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1068
1338
  // The permutations are noops and only change how the tensor data is interpreted.
1069
1339
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1070
1340
  cur = ggml_mul_mat(ctx0, v_mla, cur);
1341
+ cb(cur, "fattn_mla", il);
1071
1342
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
1072
1343
  cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
1073
1344
  #endif
1074
1345
  }
1075
1346
 
1076
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1347
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1077
1348
  } else {
1078
1349
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
1350
+ cb(kq, "kq", il);
1079
1351
 
1080
1352
  // note: this op tends to require high floating point range
1081
1353
  // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1083,42 +1355,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1083
1355
 
1084
1356
  if (arch == LLM_ARCH_GROK) {
1085
1357
  // need to do the following:
1086
- // multiply by attn_output_multiplyer of 0.08838834764831845
1358
+ // multiply by attn_output_multiplier
1087
1359
  // and then :
1088
1360
  // kq = 30 * tanh(kq / 30)
1089
1361
  // before the softmax below
1090
1362
 
1091
- kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
1092
- kq = ggml_scale(ctx0, kq, 30);
1363
+ kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1364
+ cb(kq, "kq_tanh", il);
1365
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1366
+ cb(kq, "kq_scaled", il);
1093
1367
  }
1094
1368
 
1095
1369
  if (hparams.attn_soft_cap) {
1096
1370
  kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1371
+ cb(kq, "kq_scaled_1", il);
1097
1372
  kq = ggml_tanh (ctx0, kq);
1373
+ cb(kq, "kq_tanh", il);
1098
1374
  kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1375
+ cb(kq, "kq_scaled_2", il);
1099
1376
  }
1100
1377
 
1101
1378
  if (kq_b) {
1102
1379
  kq = ggml_add(ctx0, kq, kq_b);
1380
+ cb(kq, "kq_plus_kq_b", il);
1103
1381
  }
1104
1382
 
1105
1383
  kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1384
+ ggml_soft_max_add_sinks(kq, sinks);
1385
+ cb(kq, "kq_soft_max", il);
1106
1386
 
1107
1387
  if (!v_trans) {
1108
1388
  // note: avoid this branch
1109
1389
  v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1390
+ cb(v, "v_cont", il);
1110
1391
  }
1111
1392
 
1112
1393
  ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1394
+ cb(kqv, "kqv", il);
1113
1395
 
1114
1396
  // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1115
1397
  if (v_mla) {
1116
1398
  kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1399
+ cb(kqv, "kqv_mla", il);
1117
1400
  }
1118
1401
 
1119
1402
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1120
1403
 
1121
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1404
+ // recombine streams
1405
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1122
1406
 
1123
1407
  if (!cparams.offload_kqv) {
1124
1408
  // all nodes between the KV store and the attention output are run on the CPU
@@ -1135,8 +1419,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1135
1419
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1136
1420
 
1137
1421
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1138
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1139
- //cb(inp_kq_mask, "KQ_mask", -1);
1422
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1140
1423
  ggml_set_input(inp->kq_mask);
1141
1424
 
1142
1425
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1146,13 +1429,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1146
1429
 
1147
1430
  ggml_tensor * llm_graph_context::build_attn(
1148
1431
  llm_graph_input_attn_no_cache * inp,
1149
- ggml_cgraph * gf,
1150
1432
  ggml_tensor * wo,
1151
1433
  ggml_tensor * wo_b,
1152
1434
  ggml_tensor * q_cur,
1153
1435
  ggml_tensor * k_cur,
1154
1436
  ggml_tensor * v_cur,
1155
1437
  ggml_tensor * kq_b,
1438
+ ggml_tensor * sinks,
1156
1439
  ggml_tensor * v_mla,
1157
1440
  float kq_scale,
1158
1441
  int il) const {
@@ -1166,11 +1449,16 @@ ggml_tensor * llm_graph_context::build_attn(
1166
1449
 
1167
1450
  const auto & kq_mask = inp->get_kq_mask();
1168
1451
 
1452
+ // [TAG_NO_CACHE_PAD]
1453
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1454
+ // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1455
+ //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1456
+
1169
1457
  ggml_tensor * q = q_cur;
1170
1458
  ggml_tensor * k = k_cur;
1171
1459
  ggml_tensor * v = v_cur;
1172
1460
 
1173
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1461
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1174
1462
  cb(cur, "kqv_out", il);
1175
1463
 
1176
1464
  if (wo) {
@@ -1188,35 +1476,51 @@ ggml_tensor * llm_graph_context::build_attn(
1188
1476
  return cur;
1189
1477
  }
1190
1478
 
1191
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1192
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1479
+ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1480
+ ggml_context * ctx0,
1481
+ const llama_ubatch & ubatch,
1482
+ const llama_hparams & hparams,
1483
+ const llama_cparams & cparams,
1484
+ const llama_kv_cache_context * mctx_cur) {
1193
1485
 
1194
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1486
+ auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1195
1487
 
1196
1488
  {
1197
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1489
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1490
+
1491
+ const auto n_kv = mctx_cur->get_n_kv();
1492
+ const auto n_tokens = ubatch.n_tokens;
1493
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1198
1494
 
1199
- const auto n_kv = mctx_cur->get_n_kv();
1495
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1496
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1200
1497
 
1201
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1202
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1498
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1203
1499
  ggml_set_input(inp->self_kq_mask);
1204
1500
 
1205
1501
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1206
1502
  }
1207
1503
 
1208
- return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1504
+ return inp;
1505
+ }
1506
+
1507
+ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1508
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1509
+
1510
+ auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1511
+
1512
+ return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1209
1513
  }
1210
1514
 
1211
1515
  ggml_tensor * llm_graph_context::build_attn(
1212
- llm_graph_input_attn_kv_unified * inp,
1213
- ggml_cgraph * gf,
1516
+ llm_graph_input_attn_kv * inp,
1214
1517
  ggml_tensor * wo,
1215
1518
  ggml_tensor * wo_b,
1216
1519
  ggml_tensor * q_cur,
1217
1520
  ggml_tensor * k_cur,
1218
1521
  ggml_tensor * v_cur,
1219
1522
  ggml_tensor * kq_b,
1523
+ ggml_tensor * sinks,
1220
1524
  ggml_tensor * v_mla,
1221
1525
  float kq_scale,
1222
1526
  int il) const {
@@ -1226,12 +1530,15 @@ ggml_tensor * llm_graph_context::build_attn(
1226
1530
  ggml_build_forward_expand(gf, k_cur);
1227
1531
  ggml_build_forward_expand(gf, v_cur);
1228
1532
 
1229
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1533
+ const auto * mctx_cur = inp->mctx;
1230
1534
 
1231
1535
  // store to KV cache
1232
1536
  {
1233
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1234
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1537
+ const auto & k_idxs = inp->get_k_idxs();
1538
+ const auto & v_idxs = inp->get_v_idxs();
1539
+
1540
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1541
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1235
1542
  }
1236
1543
 
1237
1544
  const auto & kq_mask = inp->get_kq_mask();
@@ -1240,13 +1547,13 @@ ggml_tensor * llm_graph_context::build_attn(
1240
1547
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1241
1548
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1242
1549
 
1243
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1550
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1244
1551
  cb(cur, "kqv_out", il);
1245
1552
 
1246
1553
  if (wo) {
1247
1554
  cur = build_lora_mm(wo, cur);
1248
- if (arch == LLM_ARCH_GLM4) {
1249
- // GLM4 seems to have numerical issues with half-precision accumulators
1555
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1556
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1250
1557
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1251
1558
  }
1252
1559
  }
@@ -1259,14 +1566,14 @@ ggml_tensor * llm_graph_context::build_attn(
1259
1566
  }
1260
1567
 
1261
1568
  ggml_tensor * llm_graph_context::build_attn(
1262
- llm_graph_input_attn_kv_unified_iswa * inp,
1263
- ggml_cgraph * gf,
1569
+ llm_graph_input_attn_kv_iswa * inp,
1264
1570
  ggml_tensor * wo,
1265
1571
  ggml_tensor * wo_b,
1266
1572
  ggml_tensor * q_cur,
1267
1573
  ggml_tensor * k_cur,
1268
1574
  ggml_tensor * v_cur,
1269
1575
  ggml_tensor * kq_b,
1576
+ ggml_tensor * sinks,
1270
1577
  ggml_tensor * v_mla,
1271
1578
  float kq_scale,
1272
1579
  int il) const {
@@ -1282,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_attn(
1282
1589
  ggml_build_forward_expand(gf, v_cur);
1283
1590
  }
1284
1591
 
1285
- const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1592
+ const auto * mctx_iswa = inp->mctx;
1286
1593
 
1287
1594
  const bool is_swa = hparams.is_swa(il);
1288
1595
 
@@ -1290,11 +1597,15 @@ ggml_tensor * llm_graph_context::build_attn(
1290
1597
 
1291
1598
  // optionally store to KV cache
1292
1599
  if (k_cur) {
1293
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1600
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
1601
+
1602
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1294
1603
  }
1295
1604
 
1296
1605
  if (v_cur) {
1297
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1606
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
1607
+
1608
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1298
1609
  }
1299
1610
 
1300
1611
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1303,7 +1614,7 @@ ggml_tensor * llm_graph_context::build_attn(
1303
1614
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1304
1615
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1305
1616
 
1306
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1617
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1307
1618
  cb(cur, "kqv_out", il);
1308
1619
 
1309
1620
  if (wo) {
@@ -1326,7 +1637,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1326
1637
 
1327
1638
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1328
1639
 
1329
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1640
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1330
1641
  ggml_set_input(inp->cross_kq_mask);
1331
1642
 
1332
1643
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1336,13 +1647,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1336
1647
 
1337
1648
  ggml_tensor * llm_graph_context::build_attn(
1338
1649
  llm_graph_input_attn_cross * inp,
1339
- ggml_cgraph * gf,
1340
1650
  ggml_tensor * wo,
1341
1651
  ggml_tensor * wo_b,
1342
1652
  ggml_tensor * q_cur,
1343
1653
  ggml_tensor * k_cur,
1344
1654
  ggml_tensor * v_cur,
1345
1655
  ggml_tensor * kq_b,
1656
+ ggml_tensor * sinks,
1346
1657
  ggml_tensor * v_mla,
1347
1658
  float kq_scale,
1348
1659
  int il) const {
@@ -1358,7 +1669,7 @@ ggml_tensor * llm_graph_context::build_attn(
1358
1669
  ggml_tensor * k = k_cur;
1359
1670
  ggml_tensor * v = v_cur;
1360
1671
 
1361
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1672
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1362
1673
  cb(cur, "kqv_out", il);
1363
1674
 
1364
1675
  if (wo) {
@@ -1376,171 +1687,124 @@ ggml_tensor * llm_graph_context::build_attn(
1376
1687
  return cur;
1377
1688
  }
1378
1689
 
1379
- ggml_tensor * llm_graph_context::build_attn(
1380
- llm_graph_input_mem_hybrid * inp,
1381
- ggml_cgraph * gf,
1382
- ggml_tensor * wo,
1383
- ggml_tensor * wo_b,
1384
- ggml_tensor * q_cur,
1385
- ggml_tensor * k_cur,
1386
- ggml_tensor * v_cur,
1387
- ggml_tensor * kq_b,
1388
- ggml_tensor * v_mla,
1389
- float kq_scale,
1390
- int il) const {
1391
- // these nodes are added to the graph together so that they are not reordered
1392
- // by doing so, the number of splits in the graph is reduced
1393
- ggml_build_forward_expand(gf, q_cur);
1394
- ggml_build_forward_expand(gf, k_cur);
1395
- ggml_build_forward_expand(gf, v_cur);
1396
-
1397
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
1398
-
1399
- // store to KV cache
1400
- {
1401
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1402
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1403
- }
1404
-
1405
- const auto & kq_mask = inp->get_kq_mask();
1406
-
1407
- ggml_tensor * q = q_cur;
1408
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1409
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1410
-
1411
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1412
- cb(cur, "kqv_out", il);
1413
-
1414
- if (wo) {
1415
- cur = build_lora_mm(wo, cur);
1416
- if (arch == LLM_ARCH_GLM4) {
1417
- // GLM4 seems to have numerical issues with half-precision accumulators
1418
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1419
- }
1420
- }
1421
-
1422
- if (wo_b) {
1423
- cur = ggml_add(ctx0, cur, wo_b);
1424
- }
1425
-
1426
- return cur;
1427
- }
1690
+ // TODO: maybe separate the inner implementation into a separate function
1691
+ // like with the non-sliding window equivalent
1692
+ // once sliding-window hybrid caches are a thing.
1693
+ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
1694
+ const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
1428
1695
 
1429
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1430
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1696
+ auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1431
1697
 
1432
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1698
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1433
1699
 
1434
1700
  {
1435
1701
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1436
1702
 
1437
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1438
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1703
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1704
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1705
+
1706
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1439
1707
  ggml_set_input(inp->self_kq_mask);
1440
1708
 
1441
1709
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1442
1710
  }
1443
1711
 
1444
1712
  {
1445
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1713
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1446
1714
 
1447
1715
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1448
1716
 
1449
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1450
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1717
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1718
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1719
+
1720
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1451
1721
  ggml_set_input(inp->self_kq_mask_swa);
1452
1722
 
1453
1723
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1454
1724
  }
1455
1725
 
1456
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1726
+ return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
1457
1727
  }
1458
1728
 
1459
1729
  ggml_tensor * llm_graph_context::build_rs(
1460
- ggml_cgraph * gf,
1461
1730
  ggml_tensor * s,
1462
- ggml_tensor * state_copy,
1731
+ ggml_tensor * state_copy_main,
1732
+ ggml_tensor * state_copy_extra,
1463
1733
  int32_t state_size,
1464
1734
  int32_t n_seqs,
1465
- uint32_t n_kv,
1466
- uint32_t kv_head,
1467
- uint32_t kv_size,
1735
+ uint32_t n_rs,
1736
+ uint32_t rs_head,
1737
+ uint32_t rs_size,
1468
1738
  int32_t rs_zero,
1469
- bool avoid_copies) const {
1739
+ const llm_graph_get_rows_fn & get_state_rows) const {
1470
1740
 
1471
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1741
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
1472
1742
 
1473
1743
  // Clear a single state which will then be copied to the other cleared states.
1474
1744
  // Note that this is a no-op when the view is zero-sized.
1475
1745
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1476
1746
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1477
1747
 
1478
- ggml_tensor * output_states;
1479
-
1480
- if (!avoid_copies) {
1481
- // copy states
1482
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1483
- // {state_size, kv_size} -> {state_size, n_seqs}
1484
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1485
- ggml_build_forward_expand(gf, output_states);
1486
- } else {
1487
- // FIXME: make the gathering operation happen before the copy below
1488
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1489
- output_states = states;
1490
- }
1748
+ // copy states
1749
+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
1750
+ // {state_size, rs_size} -> {state_size, n_seqs}
1751
+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
1752
+ ggml_build_forward_expand(gf, output_states);
1491
1753
 
1492
- // copy extra states which won't be changed further (between n_seqs and n_kv)
1493
- ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1754
+ // copy extra states which won't be changed further (between n_seqs and n_rs)
1755
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
1494
1756
  ggml_build_forward_expand(gf,
1495
1757
  ggml_cpy(ctx0,
1496
1758
  states_extra,
1497
- ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1759
+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
1498
1760
 
1499
1761
  return output_states;
1500
1762
  }
1501
1763
 
1502
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1503
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1764
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1765
+ ggml_context * ctx0,
1766
+ const llama_ubatch & ubatch,
1767
+ const llama_memory_recurrent_context * mctx_cur) {
1504
1768
 
1505
1769
  auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1506
1770
 
1507
- const auto n_rs = mctx_cur->get_n_rs();
1771
+ const int64_t n_rs = mctx_cur->get_n_rs();
1772
+ const int64_t n_seqs = ubatch.n_seqs;
1508
1773
 
1509
1774
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1510
1775
  ggml_set_input(inp->s_copy);
1511
1776
 
1512
- return (llm_graph_input_rs *) res->add_input(std::move(inp));
1777
+ inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
1778
+ inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
1779
+
1780
+ return inp;
1513
1781
  }
1514
1782
 
1515
- ggml_tensor * llm_graph_context::build_rs(
1516
- llm_graph_input_rs * inp,
1517
- ggml_cgraph * gf,
1518
- ggml_tensor * s,
1519
- int32_t state_size,
1520
- int32_t n_seqs,
1521
- bool avoid_copies) const {
1783
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1522
1784
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1523
1785
 
1524
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1786
+ auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
1787
+
1788
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1525
1789
  }
1526
1790
 
1527
1791
  ggml_tensor * llm_graph_context::build_rs(
1528
- llm_graph_input_mem_hybrid * inp,
1529
- ggml_cgraph * gf,
1792
+ llm_graph_input_rs * inp,
1530
1793
  ggml_tensor * s,
1531
1794
  int32_t state_size,
1532
1795
  int32_t n_seqs,
1533
- bool avoid_copies) const {
1534
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
1796
+ const llm_graph_get_rows_fn & get_state_rows) const {
1797
+ const auto * kv_state = inp->mctx;
1535
1798
 
1536
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1799
+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
1800
+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
1801
+ get_state_rows);
1537
1802
  }
1538
1803
 
1539
1804
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1540
1805
  llm_graph_input_rs * inp,
1541
- ggml_cgraph * gf,
1542
1806
  const llama_ubatch & ubatch,
1543
- int il) const {
1807
+ int il) const {
1544
1808
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1545
1809
 
1546
1810
  const auto token_shift_count = hparams.token_shift_count;
@@ -1550,7 +1814,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1550
1814
  ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1551
1815
 
1552
1816
  ggml_tensor * token_shift = build_rs(
1553
- inp, gf, token_shift_all,
1817
+ inp, token_shift_all,
1554
1818
  hparams.n_embd_r(), n_seqs);
1555
1819
 
1556
1820
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,8 +1842,18 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1578
1842
  );
1579
1843
  }
1580
1844
 
1845
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1846
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1847
+
1848
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
1849
+ auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1850
+
1851
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
1852
+
1853
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1854
+ }
1855
+
1581
1856
  void llm_graph_context::build_pooling(
1582
- ggml_cgraph * gf,
1583
1857
  ggml_tensor * cls,
1584
1858
  ggml_tensor * cls_b,
1585
1859
  ggml_tensor * cls_out,
@@ -1623,34 +1897,32 @@ void llm_graph_context::build_pooling(
1623
1897
  case LLAMA_POOLING_TYPE_RANK:
1624
1898
  {
1625
1899
  ggml_tensor * inp_cls = build_inp_cls();
1626
- inp = ggml_get_rows(ctx0, inp, inp_cls);
1900
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
1627
1901
 
1902
+ // classification head
1903
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1628
1904
  if (cls) {
1629
- // classification head
1630
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1631
- cur = ggml_mul_mat(ctx0, cls, inp);
1905
+ cur = ggml_mul_mat(ctx0, cls, cur);
1632
1906
  if (cls_b) {
1633
1907
  cur = ggml_add(ctx0, cur, cls_b);
1634
1908
  }
1635
1909
  cur = ggml_tanh(ctx0, cur);
1910
+ }
1636
1911
 
1637
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1638
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1639
- if (cls_out) {
1640
- cur = ggml_mul_mat(ctx0, cls_out, cur);
1641
- if (cls_out_b) {
1642
- cur = ggml_add(ctx0, cur, cls_out_b);
1643
- }
1644
- }
1645
- } else if (cls_out) {
1646
- // Single layer classification head (direct projection)
1647
- // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1648
- cur = ggml_mul_mat(ctx0, cls_out, inp);
1912
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1913
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1914
+ // Single layer classification head (direct projection)
1915
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1916
+ if (cls_out) {
1917
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1649
1918
  if (cls_out_b) {
1650
1919
  cur = ggml_add(ctx0, cur, cls_out_b);
1651
1920
  }
1652
- } else {
1653
- GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
1921
+ }
1922
+
1923
+ // softmax for qwen3 reranker
1924
+ if (arch == LLM_ARCH_QWEN3) {
1925
+ cur = ggml_soft_max(ctx0, cur);
1654
1926
  }
1655
1927
  } break;
1656
1928
  default: