whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -7,11 +7,50 @@
7
7
  #include "llama-kv-cache.h"
8
8
  #include "llama-kv-cache-iswa.h"
9
9
  #include "llama-memory-hybrid.h"
10
+ #include "llama-memory-hybrid-iswa.h"
10
11
  #include "llama-memory-recurrent.h"
11
12
 
12
13
  #include <cassert>
13
14
  #include <cmath>
14
15
  #include <cstring>
16
+ #include <numeric>
17
+ #include <sstream>
18
+ #include <unordered_set>
19
+
20
+ // dedup helpers
21
+
22
+ static ggml_tensor * build_kq_mask(
23
+ ggml_context * ctx,
24
+ const llama_kv_cache_context * mctx,
25
+ const llama_ubatch & ubatch,
26
+ const llama_cparams & cparams) {
27
+ const auto n_kv = mctx->get_n_kv();
28
+ const auto n_tokens = ubatch.n_tokens;
29
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
30
+
31
+ return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
32
+ }
33
+
34
+ static bool can_reuse_kq_mask(
35
+ ggml_tensor * kq_mask,
36
+ const llama_kv_cache_context * mctx,
37
+ const llama_ubatch & ubatch,
38
+ const llama_cparams & cparams) {
39
+ const auto n_kv = mctx->get_n_kv();
40
+ const auto n_tokens = ubatch.n_tokens;
41
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
42
+
43
+ bool res = true;
44
+
45
+ res &= (kq_mask->ne[0] == n_kv);
46
+ res &= (kq_mask->ne[1] == n_tokens/n_stream);
47
+ res &= (kq_mask->ne[2] == 1);
48
+ res &= (kq_mask->ne[3] == n_stream);
49
+
50
+ return res;
51
+ }
52
+
53
+ // impl
15
54
 
16
55
  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
17
56
  if (ubatch->token) {
@@ -21,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
21
60
  }
22
61
 
23
62
  if (ubatch->embd) {
24
- const int64_t n_embd = embd->ne[0];
63
+ GGML_ASSERT(n_embd == embd->ne[0]);
64
+
25
65
  const int64_t n_tokens = ubatch->n_tokens;
26
66
 
27
67
  ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
@@ -31,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
31
71
  bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
32
72
  bool res = true;
33
73
 
34
- res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
35
- res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
74
+ res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
75
+ res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
36
76
 
37
77
  return res;
38
78
  }
@@ -62,7 +102,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
62
102
  bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
63
103
  bool res = true;
64
104
 
65
- res &= pos->ne[0] == params.ubatch.n_tokens;
105
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
66
106
 
67
107
  return res;
68
108
  }
@@ -71,11 +111,14 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
71
111
  if (ubatch->pos && attn_scale) {
72
112
  const int64_t n_tokens = ubatch->n_tokens;
73
113
 
114
+ GGML_ASSERT(f_attn_temp_scale != 0.0f);
115
+ GGML_ASSERT(n_attn_temp_floor_scale != 0);
116
+
74
117
  std::vector<float> attn_scale_data(n_tokens, 0.0f);
75
118
  for (int i = 0; i < n_tokens; ++i) {
76
119
  const float pos = ubatch->pos[i];
77
120
  attn_scale_data[i] = std::log(
78
- std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
121
+ std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
79
122
  ) * f_attn_temp_scale + 1.0;
80
123
  }
81
124
 
@@ -92,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
92
135
 
93
136
  int32_t * data = (int32_t *) pos_bucket->data;
94
137
 
95
- for (int h = 0; h < 1; ++h) {
96
- for (int j = 0; j < n_tokens; ++j) {
97
- for (int i = 0; i < n_tokens; ++i) {
98
- data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
99
- }
138
+ for (int j = 0; j < n_tokens; ++j) {
139
+ for (int i = 0; i < n_tokens; ++i) {
140
+ data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
100
141
  }
101
142
  }
102
143
  }
@@ -144,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
144
185
  }
145
186
 
146
187
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
147
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
188
+ if (cparams.embeddings &&
189
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
190
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {
191
+
148
192
  const int64_t n_tokens = ubatch->n_tokens;
149
193
  const int64_t n_seq_tokens = ubatch->n_seq_tokens;
150
194
  const int64_t n_seqs_unq = ubatch->n_seqs_unq;
@@ -206,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
206
250
 
207
251
  const bool last = (
208
252
  cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
209
- (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
253
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
210
254
  );
211
255
 
212
256
  for (int i = 0; i < n_tokens; ++i) {
@@ -251,6 +295,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
251
295
  }
252
296
  }
253
297
 
298
+ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
299
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
300
+
301
+ this->mctx = mctx;
302
+
303
+ bool res = true;
304
+
305
+ res &= s_copy->ne[0] == mctx->get_n_rs();
306
+
307
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
308
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
309
+
310
+ res &= head == mctx->get_head();
311
+ res &= rs_z == mctx->get_rs_z();
312
+
313
+ return res;
314
+ }
315
+
254
316
  void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
255
317
  GGML_UNUSED(ubatch);
256
318
 
@@ -261,12 +323,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
261
323
  }
262
324
  }
263
325
 
264
- static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
326
+ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
265
327
  LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
266
- const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
267
- (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
268
- (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
269
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
328
+ const char * swa_type_str = "unknown";
329
+
330
+ switch (swa_type) {
331
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
332
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
333
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
334
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
335
+ };
336
+
270
337
  LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
271
338
  LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
272
339
  LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +362,65 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
295
362
  const int64_t n_kv = ubatch->n_tokens;
296
363
  const int64_t n_tokens = ubatch->n_tokens;
297
364
 
298
- GGML_ASSERT(kq_mask);
299
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
300
-
301
- float * data = (float *) kq_mask->data;
302
-
303
- // [TAG_NO_CACHE_ISWA]
304
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
305
-
306
- for (int h = 0; h < 1; ++h) {
365
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
307
366
  for (int i1 = 0; i1 < n_tokens; ++i1) {
308
367
  const llama_seq_id s1 = ubatch->seq_id[i1][0];
368
+ const llama_pos p1 = ubatch->pos[i1];
309
369
 
310
- for (int i0 = 0; i0 < n_tokens; ++i0) {
311
- float f = -INFINITY;
312
-
313
- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
314
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
370
+ const uint64_t idst = i1*n_kv;
315
371
 
316
- if (s0 != s1) {
317
- continue; // skip different sequences
318
- }
372
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
373
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
374
+ const llama_pos p0 = ubatch->pos[i0];
319
375
 
320
- if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
321
- continue; // skip future tokens for causal attention
322
- }
376
+ // mask different sequences
377
+ if (s0 != s1) {
378
+ continue;
379
+ }
323
380
 
324
- // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
325
- //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
326
- // continue; // skip masked tokens for SWA
327
- //}
381
+ // mask future tokens
382
+ if (cparams.causal_attn && p0 > p1) {
383
+ continue;
384
+ }
328
385
 
329
- // TODO: reimplement this like in llama_kv_cache_unified
330
- if (hparams.use_alibi) {
331
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
332
- } else {
333
- f = 0.0f;
334
- }
386
+ // apply SWA if any
387
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
388
+ continue;
335
389
  }
336
- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
390
+
391
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
337
392
  }
338
393
  }
394
+ };
395
+
396
+ {
397
+ GGML_ASSERT(self_kq_mask);
398
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
399
+
400
+ float * data = (float *) self_kq_mask->data;
401
+
402
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
403
+
404
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
405
+
406
+ if (debug) {
407
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
408
+ }
339
409
  }
340
- if (debug) {
341
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
410
+
411
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
412
+ GGML_ASSERT(self_kq_mask_swa);
413
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
414
+
415
+ float * data = (float *) self_kq_mask_swa->data;
416
+
417
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
418
+
419
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
420
+
421
+ if (debug) {
422
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
423
+ }
342
424
  }
343
425
  }
344
426
 
@@ -359,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
359
441
  res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
360
442
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
361
443
 
362
- res &= self_kq_mask->ne[0] == mctx->get_n_kv();
363
- res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
444
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
445
+
446
+ return res;
447
+ }
448
+
449
+ void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
450
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
451
+
452
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
453
+ }
454
+
455
+ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
456
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
457
+
458
+ this->mctx = mctx;
459
+
460
+ bool res = true;
461
+
462
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
463
+
464
+ res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
364
465
 
365
466
  return res;
366
467
  }
@@ -390,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
390
491
  res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
391
492
  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
392
493
 
393
- res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
394
- res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
395
-
396
- res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
397
- res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
494
+ res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
495
+ res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
398
496
 
399
497
  return res;
400
498
  }
@@ -410,34 +508,212 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
410
508
 
411
509
  float * data = (float *) cross_kq_mask->data;
412
510
 
413
- for (int h = 0; h < 1; ++h) {
414
- for (int i = 0; i < n_tokens; ++i) {
415
- for (int j = 0; j < n_enc; ++j) {
416
- float f = -INFINITY;
511
+ for (int i = 0; i < n_tokens; ++i) {
512
+ GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first");
513
+ for (int j = 0; j < n_enc; ++j) {
514
+ float f = -INFINITY;
417
515
 
418
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
419
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
516
+ for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
517
+ const llama_seq_id seq_id = ubatch->seq_id[i][s];
420
518
 
421
- if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
422
- f = 0.0f;
423
- }
519
+ if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
520
+ f = 0.0f;
424
521
  }
425
-
426
- data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
427
522
  }
428
- }
429
523
 
430
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
431
- for (int j = 0; j < n_enc; ++j) {
432
- data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
433
- }
524
+ data[i*n_enc + j] = f;
434
525
  }
435
526
  }
436
527
  }
437
528
 
438
529
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
439
- inp_attn->set_input(ubatch);
440
- inp_rs->set_input(ubatch);
530
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
531
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
532
+
533
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
534
+
535
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
536
+
537
+ if (inp_rs->s_copy) {
538
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
539
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
540
+
541
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
542
+ for (uint32_t i = 0; i < n_rs; ++i) {
543
+ data[i] = mctx->get_recr()->s_copy(i);
544
+ }
545
+ }
546
+ }
547
+
548
+ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
549
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
550
+
551
+ this->mctx = mctx;
552
+
553
+ bool res = true;
554
+
555
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
556
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
557
+
558
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
559
+
560
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
561
+
562
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
563
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
564
+
565
+ res &= inp_rs->head == mctx->get_recr()->get_head();
566
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
567
+
568
+ return res;
569
+ }
570
+
571
+ // TODO: Hybrid input classes are a bit redundant.
572
+ // Instead of creating a hybrid input, the graph can simply create 2 separate inputs.
573
+ // Refactoring is required in the future.
574
+ void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) {
575
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
576
+
577
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
578
+
579
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
580
+
581
+ if (inp_rs->s_copy) {
582
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
583
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
584
+
585
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
586
+ for (uint32_t i = 0; i < n_rs; ++i) {
587
+ data[i] = mctx->get_recr()->s_copy(i);
588
+ }
589
+ }
590
+ }
591
+
592
+ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
593
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
594
+
595
+ this->mctx = mctx;
596
+
597
+ bool res = true;
598
+
599
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
600
+
601
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
602
+
603
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
604
+
605
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
606
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
607
+
608
+ res &= inp_rs->head == mctx->get_recr()->get_head();
609
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
610
+
611
+ return res;
612
+ }
613
+
614
+ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
615
+ const auto * attn_ctx = mctx->get_attn();
616
+
617
+ // base tensors may not be allocated if there are no non-SWA attention layers
618
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
619
+ attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
620
+ attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
621
+
622
+ attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
623
+ }
624
+
625
+ // swa tensors may not be allocated if there are no SWA attention layers
626
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
627
+ attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
628
+ attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
629
+
630
+ attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
631
+ }
632
+
633
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
634
+
635
+ if (inp_rs->s_copy) {
636
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
637
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
638
+
639
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
640
+ for (uint32_t i = 0; i < n_rs; ++i) {
641
+ data[i] = mctx->get_recr()->s_copy(i);
642
+ }
643
+ }
644
+ }
645
+
646
+ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
647
+ const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
648
+
649
+ this->mctx = mctx;
650
+
651
+ bool res = true;
652
+
653
+ const auto * attn_ctx = mctx->get_attn();
654
+
655
+ // base tensors may not be allocated if there are no non-SWA attention layers
656
+ if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
657
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
658
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
659
+
660
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
661
+ }
662
+
663
+ // swa tensors may not be allocated if there are no SWA attention layers
664
+ if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
665
+ res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
666
+ //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
667
+
668
+ res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
669
+ }
670
+
671
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
672
+
673
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
674
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
675
+
676
+ res &= inp_rs->head == mctx->get_recr()->get_head();
677
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
678
+
679
+ return res;
680
+ }
681
+
682
+ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
683
+ // set the inputs only for the active samplers in the current ubatch
684
+ std::unordered_set<llama_seq_id> active_samplers;
685
+ for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
686
+ if (ubatch->output[i]) {
687
+ llama_seq_id seq_id = ubatch->seq_id[i][0];
688
+ active_samplers.insert(seq_id);
689
+ }
690
+ }
691
+
692
+ for (auto seq_id : active_samplers) {
693
+ if (samplers.find(seq_id) == samplers.end()) {
694
+ continue;
695
+ }
696
+
697
+ auto & sampler = samplers[seq_id];
698
+
699
+ if (sampler->iface->backend_set_input) {
700
+ sampler->iface->backend_set_input(sampler);
701
+ }
702
+ }
703
+ }
704
+
705
+ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
706
+ if (samplers.size() != params.samplers.size()) {
707
+ return false;
708
+ }
709
+
710
+ for (const auto & [seq_id, sampler] : params.samplers) {
711
+ if (samplers[seq_id] != sampler) {
712
+ return false;
713
+ }
714
+ }
715
+
716
+ return true;
441
717
  }
442
718
 
443
719
  //
@@ -456,10 +732,15 @@ int64_t llm_graph_result::get_max_nodes() const {
456
732
  }
457
733
 
458
734
  void llm_graph_result::reset() {
459
- t_tokens = nullptr;
735
+ t_inp_tokens = nullptr;
736
+ t_inp_embd = nullptr;
460
737
  t_logits = nullptr;
461
738
  t_embd = nullptr;
462
739
  t_embd_pooled = nullptr;
740
+ t_sampled.clear();
741
+ t_sampled_probs.clear();
742
+ t_sampled_logits.clear();
743
+ t_candidates.clear();
463
744
 
464
745
  params = {};
465
746
 
@@ -484,6 +765,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
484
765
  }
485
766
  }
486
767
 
768
+ void llm_graph_result::set_outputs() {
769
+ if (t_logits != nullptr) {
770
+ ggml_set_output(t_logits);
771
+ }
772
+ if (t_embd != nullptr) {
773
+ ggml_set_output(t_embd);
774
+ }
775
+ if (t_embd_pooled != nullptr) {
776
+ ggml_set_output(t_embd_pooled);
777
+ }
778
+ for (auto & [seq_id, t] : t_sampled) {
779
+ if (t != nullptr) {
780
+ ggml_set_output(t);
781
+ }
782
+ }
783
+ for (auto & [seq_id, t] : t_sampled_probs) {
784
+ if (t != nullptr) {
785
+ ggml_set_output(t);
786
+ }
787
+ }
788
+ for (auto & [seq_id, t] : t_sampled_logits) {
789
+ if (t != nullptr) {
790
+ ggml_set_output(t);
791
+ }
792
+ }
793
+ for (auto & [seq_id, t] : t_candidates) {
794
+ if (t != nullptr) {
795
+ ggml_set_output(t);
796
+ }
797
+ }
798
+ }
799
+
487
800
  bool llm_graph_result::can_reuse(const llm_graph_params & params) {
488
801
  if (!this->params.allow_reuse(params)) {
489
802
  if (debug > 1) {
@@ -536,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
536
849
  ubatch (params.ubatch),
537
850
  n_embd (hparams.n_embd),
538
851
  n_layer (hparams.n_layer),
539
- n_rot (hparams.n_rot),
852
+ n_rot (hparams.n_rot()),
540
853
  n_ctx (cparams.n_ctx),
541
854
  n_head (hparams.n_head()),
542
855
  n_head_kv (hparams.n_head_kv()),
543
- n_embd_head_k (hparams.n_embd_head_k),
856
+ n_embd_head_k (hparams.n_embd_head_k()),
544
857
  n_embd_k_gqa (hparams.n_embd_k_gqa()),
545
- n_embd_head_v (hparams.n_embd_head_v),
858
+ n_embd_head_v (hparams.n_embd_head_v()),
546
859
  n_embd_v_gqa (hparams.n_embd_v_gqa()),
547
860
  n_expert (hparams.n_expert),
548
861
  n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
@@ -565,6 +878,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
565
878
  loras (params.loras),
566
879
  mctx (params.mctx),
567
880
  cross (params.cross),
881
+ samplers (params.samplers),
568
882
  cb_func (params.cb),
569
883
  res (params.res),
570
884
  ctx0 (res->get_ctx()),
@@ -586,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec(
586
900
 
587
901
  ggml_tensor * llm_graph_context::build_lora_mm(
588
902
  ggml_tensor * w,
589
- ggml_tensor * cur) const {
903
+ ggml_tensor * cur,
904
+ ggml_tensor * w_s) const {
590
905
  ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
591
906
 
592
907
  for (const auto & lora : *loras) {
@@ -607,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm(
607
922
  res = ggml_add(ctx0, res, ab_cur);
608
923
  }
609
924
 
925
+ if (w_s) {
926
+ res = ggml_mul(ctx0, res, w_s);
927
+ }
928
+
610
929
  return res;
611
930
  }
612
931
 
@@ -732,6 +1051,26 @@ ggml_tensor * llm_graph_context::build_ffn(
732
1051
  switch (type_op) {
733
1052
  case LLM_FFN_SILU:
734
1053
  if (gate && type_gate == LLM_FFN_PAR) {
1054
+ // Step35: HF clamps gate (after SiLU) and up before multiplication
1055
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1056
+ const float limit = hparams.swiglu_clamp_shexp[il];
1057
+ constexpr float eps = 1e-6f;
1058
+ if (limit > eps) {
1059
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1060
+ cb(gate_act, "ffn_silu", il);
1061
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1062
+ cb(gate_act, "ffn_silu_clamped", il);
1063
+
1064
+ tmp = ggml_clamp(ctx0, tmp, -limit, limit);
1065
+ cb(tmp, "ffn_up_clamped", il);
1066
+
1067
+ cur = ggml_mul(ctx0, gate_act, tmp);
1068
+ cb(cur, "ffn_swiglu_limited", il);
1069
+ type_gate = LLM_FFN_SEQ;
1070
+ break;
1071
+ }
1072
+ }
1073
+
735
1074
  cur = ggml_swiglu_split(ctx0, cur, tmp);
736
1075
  cb(cur, "ffn_swiglu", il);
737
1076
  type_gate = LLM_FFN_SEQ;
@@ -795,8 +1134,8 @@ ggml_tensor * llm_graph_context::build_ffn(
795
1134
 
796
1135
  if (down) {
797
1136
  cur = build_lora_mm(down, cur);
798
- if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
799
- // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1137
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
1138
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
800
1139
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
801
1140
  }
802
1141
  }
@@ -828,11 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
828
1167
  int64_t n_expert_used,
829
1168
  llm_ffn_op_type type_op,
830
1169
  bool norm_w,
831
- bool scale_w,
832
1170
  float w_scale,
833
1171
  llama_expert_gating_func_type gating_op,
834
1172
  int il,
835
- ggml_tensor * probs_in) const {
1173
+ ggml_tensor * probs_in,
1174
+ ggml_tensor * gate_up_exps,
1175
+ ggml_tensor * up_exps_s,
1176
+ ggml_tensor * gate_exps_s,
1177
+ ggml_tensor * down_exps_s) const {
836
1178
  return build_moe_ffn(
837
1179
  cur,
838
1180
  gate_inp, /* gate_inp_b */ nullptr,
@@ -844,11 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
844
1186
  n_expert_used,
845
1187
  type_op,
846
1188
  norm_w,
847
- scale_w,
848
1189
  w_scale,
849
1190
  gating_op,
850
1191
  il,
851
- probs_in
1192
+ probs_in,
1193
+ gate_up_exps,
1194
+ /* gate_up_exps_b */ nullptr,
1195
+ up_exps_s,
1196
+ gate_exps_s,
1197
+ down_exps_s
852
1198
  );
853
1199
  }
854
1200
 
@@ -867,11 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
867
1213
  int64_t n_expert_used,
868
1214
  llm_ffn_op_type type_op,
869
1215
  bool norm_w,
870
- bool scale_w,
871
1216
  float w_scale,
872
1217
  llama_expert_gating_func_type gating_op,
873
1218
  int il,
874
- ggml_tensor * probs_in) const {
1219
+ ggml_tensor * probs_in,
1220
+ ggml_tensor * gate_up_exps,
1221
+ ggml_tensor * gate_up_exps_b,
1222
+ ggml_tensor * up_exps_s,
1223
+ ggml_tensor * gate_exps_s,
1224
+ ggml_tensor * down_exps_s) const {
875
1225
  const int64_t n_embd = cur->ne[0];
876
1226
  const int64_t n_tokens = cur->ne[1];
877
1227
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -928,8 +1278,33 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
928
1278
  cb(selection_probs, "ffn_moe_probs_biased", il);
929
1279
  }
930
1280
 
1281
+ // select top n_group_used expert groups
1282
+ // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1283
+ if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1284
+ const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1285
+
1286
+ // organize experts into n_expert_groups
1287
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1288
+
1289
+ ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1290
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1291
+
1292
+ // get top n_group_used expert groups
1293
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1294
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1295
+
1296
+ ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1297
+ cb(expert_groups, "ffn_moe_group_topk", il);
1298
+
1299
+ // mask out the other groups
1300
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1301
+ selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1302
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1303
+ cb(selection_probs, "ffn_moe_probs_masked", il);
1304
+ }
1305
+
931
1306
  // select experts
932
- ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1307
+ ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
933
1308
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
934
1309
  cb(selected_experts, "ffn_moe_topk", il);
935
1310
 
@@ -959,12 +1334,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
959
1334
  ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
960
1335
  cb(weights_sum, "ffn_moe_weights_sum", il);
961
1336
 
1337
+ // Avoid division by zero, clamp to smallest number representable by F16
1338
+ weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1339
+ cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1340
+
962
1341
  weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
963
1342
  cb(weights, "ffn_moe_weights_norm", il);
964
1343
 
965
1344
  weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
966
1345
  }
967
- if (scale_w) {
1346
+ if (w_scale != 0.0f && w_scale != 1.0f) {
968
1347
  weights = ggml_scale(ctx0, weights, w_scale);
969
1348
  cb(weights, "ffn_moe_weights_scaled", il);
970
1349
  }
@@ -981,30 +1360,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
981
1360
  cb(cur, "ffn_moe_weighted", il);
982
1361
  }
983
1362
 
984
- ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
985
- cb(up, "ffn_moe_up", il);
1363
+ ggml_tensor * up = nullptr;
1364
+ ggml_tensor * experts = nullptr;
986
1365
 
987
- if (up_exps_b) {
988
- up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
989
- cb(up, "ffn_moe_up_biased", il);
990
- }
1366
+ if (gate_up_exps) {
1367
+ // merged gate_up path: one mul_mat_id, then split into gate and up views
1368
+ ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
1369
+ cb(gate_up, "ffn_moe_gate_up", il);
991
1370
 
992
- ggml_tensor * experts = nullptr;
993
- if (gate_exps) {
994
- cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1371
+ if (gate_up_exps_b) {
1372
+ gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
1373
+ cb(gate_up, "ffn_moe_gate_up_biased", il);
1374
+ }
1375
+
1376
+ // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
1377
+ if (up_exps_s) {
1378
+ ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1379
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1380
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1381
+ gate_up = ggml_mul(ctx0, gate_up, s);
1382
+ cb(gate_up, "ffn_moe_gate_up_scaled", il);
1383
+ }
1384
+
1385
+ const int64_t n_ff = gate_up->ne[0] / 2;
1386
+ cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
995
1387
  cb(cur, "ffn_moe_gate", il);
1388
+ up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
1389
+ cb(up, "ffn_moe_up", il);
996
1390
  } else {
997
- cur = up;
998
- }
1391
+ // separate gate and up path
1392
+ up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1393
+ cb(up, "ffn_moe_up", il);
1394
+
1395
+ if (up_exps_b) {
1396
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
1397
+ cb(up, "ffn_moe_up_biased", il);
1398
+ }
1399
+
1400
+ // apply per-expert scale2 to up
1401
+ if (up_exps_s) {
1402
+ ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
1403
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1404
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1405
+ up = ggml_mul(ctx0, up, s);
1406
+ cb(up, "ffn_moe_up_scaled", il);
1407
+ }
1408
+
1409
+ if (gate_exps) {
1410
+ cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
1411
+ cb(cur, "ffn_moe_gate", il);
1412
+ } else {
1413
+ cur = up;
1414
+ }
1415
+
1416
+ if (gate_exps_b) {
1417
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1418
+ cb(cur, "ffn_moe_gate_biased", il);
1419
+ }
999
1420
 
1000
- if (gate_exps_b) {
1001
- cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
1002
- cb(cur, "ffn_moe_gate_biased", il);
1421
+ // apply per-expert scale2 to gate
1422
+ if (gate_exps_s) {
1423
+ ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
1424
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1425
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1426
+ cur = ggml_mul(ctx0, cur, s);
1427
+ cb(cur, "ffn_moe_gate_scaled", il);
1428
+ }
1003
1429
  }
1004
1430
 
1431
+ const bool has_gate = gate_exps || gate_up_exps;
1432
+
1005
1433
  switch (type_op) {
1006
1434
  case LLM_FFN_SILU:
1007
1435
  if (gate_exps) {
1436
+ // Step35: per-layer clamp for routed experts
1437
+ if (arch == LLM_ARCH_STEP35 && il >= 0) {
1438
+ const float limit = hparams.swiglu_clamp_exp[il];
1439
+ constexpr float eps = 1e-6f;
1440
+ if (limit > eps) {
1441
+ ggml_tensor * gate_act = ggml_silu(ctx0, cur);
1442
+ cb(gate_act, "ffn_moe_silu", il);
1443
+ gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
1444
+ cb(gate_act, "ffn_moe_silu_clamped", il);
1445
+
1446
+ up = ggml_clamp(ctx0, up, -limit, limit);
1447
+ cb(up, "ffn_moe_up_clamped", il);
1448
+
1449
+ cur = ggml_mul(ctx0, gate_act, up);
1450
+ cb(cur, "ffn_moe_swiglu_limited", il);
1451
+ break;
1452
+ }
1453
+ }
1454
+ }
1455
+
1456
+ if (has_gate) {
1008
1457
  cur = ggml_swiglu_split(ctx0, cur, up);
1009
1458
  cb(cur, "ffn_moe_swiglu", il);
1010
1459
  } else {
@@ -1012,7 +1461,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1012
1461
  cb(cur, "ffn_moe_silu", il);
1013
1462
  } break;
1014
1463
  case LLM_FFN_GELU:
1015
- if (gate_exps) {
1464
+ if (has_gate) {
1016
1465
  cur = ggml_geglu_split(ctx0, cur, up);
1017
1466
  cb(cur, "ffn_moe_geglu", il);
1018
1467
  } else {
@@ -1028,13 +1477,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1028
1477
  cb(cur, "ffn_moe_swiglu_oai", il);
1029
1478
  } break;
1030
1479
  case LLM_FFN_RELU:
1031
- if (gate_exps) {
1480
+ if (has_gate) {
1032
1481
  cur = ggml_reglu_split(ctx0, cur, up);
1033
1482
  cb(cur, "ffn_moe_reglu", il);
1034
1483
  } else {
1035
1484
  cur = ggml_relu(ctx0, cur);
1036
1485
  cb(cur, "ffn_moe_relu", il);
1037
1486
  } break;
1487
+ case LLM_FFN_RELU_SQR:
1488
+ if (has_gate) {
1489
+ // TODO: add support for gated squared relu
1490
+ GGML_ABORT("fatal error: gated squared relu not implemented");
1491
+ } else {
1492
+ cur = ggml_relu(ctx0, cur);
1493
+ cur = ggml_sqr(ctx0, cur);
1494
+ cb(cur, "ffn_moe_relu_sqr", il);
1495
+ } break;
1038
1496
  default:
1039
1497
  GGML_ABORT("fatal error");
1040
1498
  }
@@ -1047,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1047
1505
  cb(experts, "ffn_moe_down_biased", il);
1048
1506
  }
1049
1507
 
1508
+ // apply per-expert scale2 to down
1509
+ if (down_exps_s) {
1510
+ ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
1511
+ s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
1512
+ s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
1513
+ experts = ggml_mul(ctx0, experts, s);
1514
+ cb(experts, "ffn_moe_down_scaled", il);
1515
+ }
1516
+
1050
1517
  if (!weight_before_ffn) {
1051
1518
  experts = ggml_mul(ctx0, experts, weights);
1052
1519
  cb(cur, "ffn_moe_weighted", il);
@@ -1085,17 +1552,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1085
1552
 
1086
1553
  // input embeddings with optional lora
1087
1554
  ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1088
- const int64_t n_embd = hparams.n_embd;
1555
+ const int64_t n_embd_inp = hparams.n_embd_inp();
1556
+ const int64_t n_embd = hparams.n_embd;
1557
+
1558
+ assert(n_embd_inp >= n_embd);
1559
+
1560
+ auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1089
1561
 
1090
- auto inp = std::make_unique<llm_graph_input_embd>();
1562
+ inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1563
+ cb(inp->tokens, "inp_tokens", -1);
1564
+ ggml_set_input(inp->tokens);
1565
+ res->t_inp_tokens = inp->tokens;
1091
1566
 
1092
- ggml_tensor * cur = nullptr;
1567
+ inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
1568
+ cb(inp->embd, "inp_embd", -1);
1569
+ ggml_set_input(inp->embd);
1093
1570
 
1094
- if (ubatch.token) {
1095
- inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
1096
- //cb(inp->tokens, "inp_tokens", -1);
1097
- ggml_set_input(inp->tokens);
1098
- res->t_tokens = inp->tokens;
1571
+ // select one of the 2 inputs, based on the batch contents
1572
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1573
+ std::array<ggml_tensor *, 2> inps;
1574
+
1575
+ // token embeddings path (ubatch.token != nullptr)
1576
+ {
1577
+ auto & cur = inps[0];
1099
1578
 
1100
1579
  cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
1101
1580
 
@@ -1116,22 +1595,43 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1116
1595
 
1117
1596
  cur = ggml_add(ctx0, cur, inpL_delta);
1118
1597
  }
1119
- } else {
1120
- inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
1121
- ggml_set_input(inp->embd);
1598
+
1599
+ if (n_embd_inp != n_embd) {
1600
+ cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
1601
+ }
1602
+ }
1603
+
1604
+ // vector embeddings path (ubatch.embd != nullptr)
1605
+ {
1606
+ auto & cur = inps[1];
1122
1607
 
1123
1608
  cur = inp->embd;
1124
1609
  }
1125
1610
 
1611
+ assert(ggml_are_same_shape (inps[0], inps[1]));
1612
+ assert(ggml_are_same_stride(inps[0], inps[1]));
1613
+
1614
+ ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
1615
+
1616
+ if (n_embd_inp != n_embd) {
1617
+ cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
1618
+ }
1619
+
1620
+ res->t_inp_embd = cur;
1621
+
1126
1622
  // For Granite architecture
1127
1623
  if (hparams.f_embedding_scale != 0.0f) {
1128
1624
  cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
1129
1625
  }
1130
1626
 
1131
- cb(cur, "inp_embd", -1);
1627
+ cb(cur, "embd", -1);
1132
1628
 
1133
1629
  res->add_input(std::move(inp));
1134
1630
 
1631
+ // make sure the produced embeddings are immediately materialized in the ggml graph
1632
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18599
1633
+ ggml_build_forward_expand(gf, cur);
1634
+
1135
1635
  return cur;
1136
1636
  }
1137
1637
 
@@ -1149,13 +1649,14 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
1149
1649
  }
1150
1650
 
1151
1651
  ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1152
- auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
1652
+ auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
1153
1653
 
1154
1654
  auto & cur = inp->attn_scale;
1155
1655
 
1156
1656
  // this need to be 1x1xN for broadcasting
1157
1657
  cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
1158
1658
  ggml_set_input(cur);
1659
+ ggml_set_name(cur, "attn_scale");
1159
1660
 
1160
1661
  res->add_input(std::move(inp));
1161
1662
 
@@ -1165,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1165
1666
  ggml_tensor * llm_graph_context::build_inp_out_ids() const {
1166
1667
  // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
1167
1668
  // but this would make the graph topology depend on the number of output tokens, which can interere with
1168
- // features that require constant topology such as pipline parallelism
1669
+ // features that require constant topology such as pipeline parallelism
1169
1670
  // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
1170
1671
  //if (n_outputs < n_tokens) {
1171
1672
  // return nullptr;
@@ -1222,8 +1723,8 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1222
1723
  // return cur;
1223
1724
  //}
1224
1725
 
1225
- const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
1226
- const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1726
+ const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1727
+ const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1227
1728
 
1228
1729
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
1229
1730
  ggml_set_input(cur);
@@ -1299,12 +1800,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1299
1800
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1300
1801
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1301
1802
 
1302
- const auto n_kv = k->ne[1];
1303
-
1304
1803
  ggml_tensor * cur;
1305
1804
 
1306
- // TODO: replace hardcoded padding with ggml-provided padding
1307
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
1805
+ const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
1806
+ if (use_flash_attn) {
1308
1807
  GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1309
1808
 
1310
1809
  if (v_trans) {
@@ -1330,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1330
1829
  if (v_mla) {
1331
1830
  #if 0
1332
1831
  // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
1333
- // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
1832
+ // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
1334
1833
  cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
1335
1834
  cur = ggml_mul_mat(ctx0, v_mla, cur);
1336
1835
  #else
@@ -1419,10 +1918,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1419
1918
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1420
1919
 
1421
1920
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1422
- inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1423
- ggml_set_input(inp->kq_mask);
1921
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1922
+ ggml_set_input(inp->self_kq_mask);
1424
1923
 
1425
- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
1924
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1925
+
1926
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1927
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1928
+ ggml_set_input(inp->self_kq_mask_swa);
1929
+
1930
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1931
+ } else {
1932
+ inp->self_kq_mask_swa = nullptr;
1933
+ inp->self_kq_mask_swa_cnv = nullptr;
1934
+ }
1426
1935
 
1427
1936
  return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1428
1937
  }
@@ -1447,7 +1956,9 @@ ggml_tensor * llm_graph_context::build_attn(
1447
1956
  ggml_build_forward_expand(gf, k_cur);
1448
1957
  ggml_build_forward_expand(gf, v_cur);
1449
1958
 
1450
- const auto & kq_mask = inp->get_kq_mask();
1959
+ const bool is_swa = hparams.is_swa(il);
1960
+
1961
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1451
1962
 
1452
1963
  // [TAG_NO_CACHE_PAD]
1453
1964
  // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
@@ -1488,14 +1999,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1488
1999
  {
1489
2000
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1490
2001
 
1491
- const auto n_kv = mctx_cur->get_n_kv();
1492
- const auto n_tokens = ubatch.n_tokens;
1493
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1494
-
1495
2002
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1496
2003
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1497
2004
 
1498
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
2005
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2006
+
1499
2007
  ggml_set_input(inp->self_kq_mask);
1500
2008
 
1501
2009
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1521,14 +2029,17 @@ ggml_tensor * llm_graph_context::build_attn(
1521
2029
  ggml_tensor * v_cur,
1522
2030
  ggml_tensor * kq_b,
1523
2031
  ggml_tensor * sinks,
1524
- ggml_tensor * v_mla,
2032
+ ggml_tensor * v_mla, // TODO: remove
1525
2033
  float kq_scale,
1526
2034
  int il) const {
2035
+ GGML_ASSERT(v_mla == nullptr);
2036
+
1527
2037
  // these nodes are added to the graph together so that they are not reordered
1528
2038
  // by doing so, the number of splits in the graph is reduced
2039
+ // expand k later to enable rope fusion which directly writes into k-v cache
1529
2040
  ggml_build_forward_expand(gf, q_cur);
1530
- ggml_build_forward_expand(gf, k_cur);
1531
2041
  ggml_build_forward_expand(gf, v_cur);
2042
+ ggml_build_forward_expand(gf, k_cur);
1532
2043
 
1533
2044
  const auto * mctx_cur = inp->mctx;
1534
2045
 
@@ -1550,6 +2061,89 @@ ggml_tensor * llm_graph_context::build_attn(
1550
2061
  ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1551
2062
  cb(cur, "kqv_out", il);
1552
2063
 
2064
+ if (wo) {
2065
+ cur = build_lora_mm(wo, cur);
2066
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
2067
+ // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
2068
+ ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
2069
+ }
2070
+ }
2071
+
2072
+ if (wo_b) {
2073
+ cur = ggml_add(ctx0, cur, wo_b);
2074
+ }
2075
+
2076
+ return cur;
2077
+ }
2078
+
2079
+ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
2080
+ ggml_context * ctx0,
2081
+ const llama_ubatch & ubatch,
2082
+ const llama_hparams & hparams,
2083
+ const llama_cparams & cparams,
2084
+ const llama_kv_cache_context * mctx_cur) {
2085
+
2086
+ auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
2087
+
2088
+ {
2089
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
2090
+
2091
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
2092
+
2093
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
2094
+ ggml_set_input(inp->self_kq_mask);
2095
+
2096
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2097
+ }
2098
+
2099
+ return inp;
2100
+ }
2101
+
2102
+ llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
2103
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
2104
+
2105
+ auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
2106
+
2107
+ return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
2108
+ }
2109
+
2110
+ ggml_tensor * llm_graph_context::build_attn(
2111
+ llm_graph_input_attn_k * inp,
2112
+ ggml_tensor * wo,
2113
+ ggml_tensor * wo_b,
2114
+ ggml_tensor * q_cur,
2115
+ ggml_tensor * k_cur,
2116
+ ggml_tensor * v_cur,
2117
+ ggml_tensor * kq_b,
2118
+ ggml_tensor * sinks,
2119
+ ggml_tensor * v_mla,
2120
+ float kq_scale,
2121
+ int il) const {
2122
+ // these nodes are added to the graph together so that they are not reordered
2123
+ // by doing so, the number of splits in the graph is reduced
2124
+ // expand k later to enable rope fusion which directly writes into k-v cache
2125
+ ggml_build_forward_expand(gf, q_cur);
2126
+ ggml_build_forward_expand(gf, v_cur);
2127
+ ggml_build_forward_expand(gf, k_cur);
2128
+
2129
+ const auto * mctx_cur = inp->mctx;
2130
+
2131
+ // store to KV cache
2132
+ {
2133
+ const auto & k_idxs = inp->get_k_idxs();
2134
+
2135
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
2136
+ }
2137
+
2138
+ const auto & kq_mask = inp->get_kq_mask();
2139
+
2140
+ ggml_tensor * q = q_cur;
2141
+ ggml_tensor * k = mctx_cur->get_k(ctx0, il);
2142
+ ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
2143
+
2144
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
2145
+ cb(cur, "kqv_out", il);
2146
+
1553
2147
  if (wo) {
1554
2148
  cur = build_lora_mm(wo, cur);
1555
2149
  if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
@@ -1637,7 +2231,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1637
2231
 
1638
2232
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1639
2233
 
1640
- inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
2234
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
1641
2235
  ggml_set_input(inp->cross_kq_mask);
1642
2236
 
1643
2237
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1695,32 +2289,30 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
1695
2289
 
1696
2290
  auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1697
2291
 
1698
- const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1699
-
1700
2292
  {
1701
- const auto n_kv = mctx_cur->get_base()->get_n_kv();
1702
-
1703
2293
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1704
2294
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1705
2295
 
1706
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
2296
+ inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
1707
2297
  ggml_set_input(inp->self_kq_mask);
2298
+ ggml_set_name(inp->self_kq_mask, "self_kq_mask");
1708
2299
 
1709
2300
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
2301
+ ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
1710
2302
  }
1711
2303
 
1712
2304
  {
1713
2305
  GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1714
2306
 
1715
- const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1716
-
1717
2307
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1718
2308
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1719
2309
 
1720
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
2310
+ inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
1721
2311
  ggml_set_input(inp->self_kq_mask_swa);
2312
+ ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
1722
2313
 
1723
2314
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
2315
+ ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
1724
2316
  }
1725
2317
 
1726
2318
  return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -1777,6 +2369,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1777
2369
  inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
1778
2370
  inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
1779
2371
 
2372
+ inp->head = mctx_cur->get_head();
2373
+ inp->rs_z = mctx_cur->get_rs_z();
2374
+
1780
2375
  return inp;
1781
2376
  }
1782
2377
 
@@ -1845,19 +2440,91 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1845
2440
  llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1846
2441
  const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1847
2442
 
1848
- auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2443
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
1849
2444
  auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1850
2445
 
1851
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
2446
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
1852
2447
 
1853
2448
  return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1854
2449
  }
1855
2450
 
2451
+ llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const {
2452
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2453
+
2454
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2455
+ auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2456
+
2457
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2458
+
2459
+ return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp));
2460
+ }
2461
+
2462
+ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
2463
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
2464
+
2465
+ auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2466
+
2467
+ // build iswa attention input
2468
+ const auto * attn_ctx = mctx_cur->get_attn();
2469
+
2470
+ auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
2471
+
2472
+ {
2473
+ inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
2474
+ inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
2475
+
2476
+ inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
2477
+ ggml_set_input(inp_attn->self_kq_mask);
2478
+
2479
+ inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
2480
+ }
2481
+
2482
+ {
2483
+ inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
2484
+ inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
2485
+
2486
+ inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
2487
+ ggml_set_input(inp_attn->self_kq_mask_swa);
2488
+
2489
+ inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
2490
+ }
2491
+
2492
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2493
+
2494
+ return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
2495
+ }
2496
+
2497
+ void llm_graph_context::build_dense_out(
2498
+ ggml_tensor * dense_2,
2499
+ ggml_tensor * dense_2_b,
2500
+ ggml_tensor * dense_3) const {
2501
+ if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
2502
+ return;
2503
+ }
2504
+ ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2505
+ GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2506
+
2507
+ if (dense_2) {
2508
+ cur = ggml_mul_mat(ctx0, dense_2, cur);
2509
+ }
2510
+ if (dense_2_b) {
2511
+ cur = ggml_add(ctx0, cur, dense_2_b);
2512
+ }
2513
+ if (dense_3) {
2514
+ cur = ggml_mul_mat(ctx0, dense_3, cur);
2515
+ }
2516
+ cb(cur, "result_embd_pooled", -1);
2517
+ res->t_embd_pooled = cur;
2518
+ ggml_build_forward_expand(gf, cur);
2519
+ }
2520
+
2521
+
1856
2522
  void llm_graph_context::build_pooling(
1857
2523
  ggml_tensor * cls,
1858
2524
  ggml_tensor * cls_b,
1859
2525
  ggml_tensor * cls_out,
1860
- ggml_tensor * cls_out_b) const {
2526
+ ggml_tensor * cls_out_b,
2527
+ ggml_tensor * cls_norm) const {
1861
2528
  if (!cparams.embeddings) {
1862
2529
  return;
1863
2530
  }
@@ -1896,8 +2563,15 @@ void llm_graph_context::build_pooling(
1896
2563
  } break;
1897
2564
  case LLAMA_POOLING_TYPE_RANK:
1898
2565
  {
1899
- ggml_tensor * inp_cls = build_inp_cls();
1900
- cur = ggml_get_rows(ctx0, inp, inp_cls);
2566
+ if (arch == LLM_ARCH_MODERN_BERT) {
2567
+ // modern bert gte reranker builds mean first then applies prediction head and classifier
2568
+ // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
2569
+ ggml_tensor * inp_mean = build_inp_mean();
2570
+ cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
2571
+ } else {
2572
+ ggml_tensor * inp_cls = build_inp_cls();
2573
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
2574
+ }
1901
2575
 
1902
2576
  // classification head
1903
2577
  // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
@@ -1906,7 +2580,15 @@ void llm_graph_context::build_pooling(
1906
2580
  if (cls_b) {
1907
2581
  cur = ggml_add(ctx0, cur, cls_b);
1908
2582
  }
1909
- cur = ggml_tanh(ctx0, cur);
2583
+ if (arch == LLM_ARCH_MODERN_BERT) {
2584
+ cur = ggml_gelu(ctx0, cur);
2585
+ } else {
2586
+ cur = ggml_tanh(ctx0, cur);
2587
+ }
2588
+ if (cls_norm) {
2589
+ // head norm
2590
+ cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
2591
+ }
1910
2592
  }
1911
2593
 
1912
2594
  // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
@@ -1921,7 +2603,7 @@ void llm_graph_context::build_pooling(
1921
2603
  }
1922
2604
 
1923
2605
  // softmax for qwen3 reranker
1924
- if (arch == LLM_ARCH_QWEN3) {
2606
+ if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
1925
2607
  cur = ggml_soft_max(ctx0, cur);
1926
2608
  }
1927
2609
  } break;
@@ -1937,6 +2619,94 @@ void llm_graph_context::build_pooling(
1937
2619
  ggml_build_forward_expand(gf, cur);
1938
2620
  }
1939
2621
 
2622
+ void llm_graph_context::build_sampling() const {
2623
+ if (samplers.empty() || !res->t_logits) {
2624
+ return;
2625
+ }
2626
+
2627
+ std::array<ggml_tensor *, 2> outs;
2628
+ outs[0] = res->t_logits;
2629
+
2630
+ auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2631
+ res->add_input(std::move(inp_sampling));
2632
+
2633
+ std::map<llama_seq_id, int32_t> seq_to_logit_row;
2634
+ int32_t logit_row_idx = 0;
2635
+
2636
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2637
+ if (ubatch.output[i]) {
2638
+ llama_seq_id seq_id = ubatch.seq_id[i][0];
2639
+ seq_to_logit_row[seq_id] = logit_row_idx;
2640
+ logit_row_idx++;
2641
+ }
2642
+ }
2643
+
2644
+ // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
2645
+ GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2646
+
2647
+ // add a dummy row of logits
2648
+ // this trick makes the graph static, regardless of which samplers are activated
2649
+ // this is important in order to minimize graph reallocations
2650
+ ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2651
+
2652
+ for (const auto & [seq_id, sampler] : samplers) {
2653
+ const auto it = seq_to_logit_row.find(seq_id);
2654
+
2655
+ // inactive samplers always work on the first row
2656
+ const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2657
+ const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
2658
+
2659
+ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2660
+ ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2661
+
2662
+ struct llama_sampler_data data = {
2663
+ /*.logits =*/ logits_seq,
2664
+ /*.probs =*/ nullptr,
2665
+ /*.sampled =*/ nullptr,
2666
+ /*.candidates =*/ nullptr,
2667
+ };
2668
+
2669
+ assert(sampler->iface->backend_apply);
2670
+ sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2671
+
2672
+ if (data.sampled != nullptr) {
2673
+ res->t_sampled[seq_id] = data.sampled;
2674
+ outs[1] = data.sampled;
2675
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2676
+ }
2677
+
2678
+ if (data.probs != nullptr) {
2679
+ res->t_sampled_probs[seq_id] = data.probs;
2680
+ outs[1] = data.probs;
2681
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2682
+ }
2683
+
2684
+ if (data.logits != nullptr) {
2685
+ res->t_sampled_logits[seq_id] = data.logits;
2686
+ outs[1] = data.logits;
2687
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2688
+ }
2689
+
2690
+ if (data.candidates != nullptr) {
2691
+ res->t_candidates[seq_id] = data.candidates;
2692
+ outs[1] = data.candidates;
2693
+ ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
2694
+ }
2695
+ }
2696
+
2697
+ // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2698
+ /*
2699
+ for (const auto & [seq_id, sampler] : samplers) {
2700
+ if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2701
+ ggml_tensor * selected_token = it->second;
2702
+ if (selected_token != nullptr) {
2703
+ llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2704
+ }
2705
+ }
2706
+ }
2707
+ */
2708
+ }
2709
+
1940
2710
  int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1941
2711
  // TODO move to hparams if a T5 variant appears that uses a different value
1942
2712
  const int64_t max_distance = 128;
@@ -1952,7 +2722,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
1952
2722
 
1953
2723
  if (bidirectional) {
1954
2724
  relative_bucket += (relative_position > 0) * n_buckets;
1955
- relative_position = abs(relative_position);
2725
+ relative_position = std::abs(relative_position);
1956
2726
  } else {
1957
2727
  relative_position = -std::min<int32_t>(relative_position, 0);
1958
2728
  }