whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
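To pick up this release in a Bundler-managed project, the dependency pin would look like the following. This is a minimal sketch: only the gem name and version numbers come from the header above; the pessimistic version constraint is illustrative, not prescribed by the package.

    # Gemfile
    # Upgrade from 1.3.3 to the 1.3.5 release of the whispercpp bindings.
    gem "whispercpp", "~> 1.3.5"

Running `bundle update whispercpp` would then pull in the files listed below.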
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -4,14 +4,15 @@
  #include "llama-batch.h"
  #include "llama-cparams.h"

- #include "llama-kv-cache-unified.h"
- #include "llama-kv-cache-unified-iswa.h"
+ #include "llama-kv-cache.h"
+ #include "llama-kv-cache-iswa.h"
  #include "llama-memory-hybrid.h"
  #include "llama-memory-recurrent.h"

  #include <cassert>
  #include <cmath>
  #include <cstring>
+ #include <unordered_set>

  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  if (ubatch->token) {
@@ -28,6 +29,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
  }
  }

+ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+ bool res = true;
+
+ res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
+
+ return res;
+ }
+
  void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  if (ubatch->pos && pos) {
  const int64_t n_tokens = ubatch->n_tokens;
@@ -50,15 +60,26 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
  }
  }

+ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+ bool res = true;
+
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
+
+ return res;
+ }
+
  void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
  if (ubatch->pos && attn_scale) {
  const int64_t n_tokens = ubatch->n_tokens;

+ GGML_ASSERT(f_attn_temp_scale != 0.0f);
+ GGML_ASSERT(n_attn_temp_floor_scale != 0);
+
  std::vector<float> attn_scale_data(n_tokens, 0.0f);
  for (int i = 0; i < n_tokens; ++i) {
  const float pos = ubatch->pos[i];
  attn_scale_data[i] = std::log(
- std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+ std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
  ) * f_attn_temp_scale + 1.0;
  }

@@ -71,7 +92,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
  const int64_t n_tokens = ubatch->n_tokens;

  GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

  int32_t * data = (int32_t *) pos_bucket->data;

@@ -118,6 +139,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
  }
  }

+ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+ bool res = true;
+
+ res &= n_outputs == params.n_outputs;
+
+ return res;
+ }
+
  void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
  if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
  const int64_t n_tokens = ubatch->n_tokens;
@@ -163,38 +192,26 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {

  void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  const int64_t n_tokens = ubatch->n_tokens;
- const int64_t n_seq_tokens = ubatch->n_seq_tokens;
  const int64_t n_seqs_unq = ubatch->n_seqs_unq;

  if (cparams.embeddings && (
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
- )) {
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
+ )) {
  GGML_ASSERT(cls);
  GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

  uint32_t * data = (uint32_t *) cls->data;
  memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

- for (int i = 0; i < n_tokens; i += n_seq_tokens) {
- for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
- const llama_seq_id seq_id = ubatch->seq_id[i][s];
- const int32_t seq_idx = ubatch->seq_idx[seq_id];
-
- data[seq_idx] = i;
- }
- }
- }
-
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
- GGML_ASSERT(cls);
- GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
-
- uint32_t * data = (uint32_t *) cls->data;
- memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+ std::vector<int> target_pos(n_seqs_unq, -1);
+ std::vector<int> target_row(n_seqs_unq, -1);

- std::vector<int> last_pos(n_seqs_unq, -1);
- std::vector<int> last_row(n_seqs_unq, -1);
+ const bool last = (
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+ (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+ );

  for (int i = 0; i < n_tokens; ++i) {
  const llama_pos pos = ubatch->pos[i];
@@ -203,16 +220,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
  const llama_seq_id seq_id = ubatch->seq_id[i][s];
  const int32_t seq_idx = ubatch->seq_idx[seq_id];

- if (pos >= last_pos[seq_idx]) {
- last_pos[seq_idx] = pos;
- last_row[seq_idx] = i;
+ if (
+ (target_pos[seq_idx] == -1) ||
+ ( last && pos >= target_pos[seq_idx]) ||
+ (!last && pos < target_pos[seq_idx])
+ ) {
+ target_pos[seq_idx] = pos;
+ target_row[seq_idx] = i;
  }
  }
  }

  for (int s = 0; s < n_seqs_unq; ++s) {
- if (last_row[s] >= 0) {
- data[s] = last_row[s];
+ if (target_row[s] >= 0) {
+ data[s] = target_row[s];
  }
  }
  }
@@ -234,6 +255,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
  }
  }

+ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= s_copy->ne[0] == mctx->get_n_rs();
+
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+ res &= head == mctx->get_head();
+ res &= rs_z == mctx->get_rs_z();
+
+ return res;
+ }
+
  void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  GGML_UNUSED(ubatch);

@@ -244,56 +283,164 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
  }
  }

+ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+ LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+ const char * swa_type_str = "unknown";
+
+ switch (swa_type) {
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
+ };
+
+ LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+ LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+ LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+ LLAMA_LOG_DEBUG(" ");
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ LLAMA_LOG_DEBUG("%2d", j);
+ }
+ LLAMA_LOG_DEBUG("\n");
+
+ for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+ LLAMA_LOG_DEBUG(" %2d ", i);
+ for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+ float val = data[i * n_kv + j];
+ if (val == -INFINITY) {
+ LLAMA_LOG_DEBUG(" ∞");
+ } else {
+ LLAMA_LOG_DEBUG(" 0");
+ }
+ }
+ LLAMA_LOG_DEBUG("\n");
+ }
+ }
+
  void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
  const int64_t n_kv = ubatch->n_tokens;
  const int64_t n_tokens = ubatch->n_tokens;

- GGML_ASSERT(kq_mask);
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
+ for (int h = 0; h < 1; ++h) {
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ const llama_pos p1 = ubatch->pos[i1];

- float * data = (float *) kq_mask->data;
+ const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;

- for (int h = 0; h < 1; ++h) {
- for (int i1 = 0; i1 < n_tokens; ++i1) {
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
+ const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ const llama_pos p0 = ubatch->pos[i0];

- for (int i0 = 0; i0 < n_tokens; ++i0) {
- float f = -INFINITY;
+ // mask different sequences
+ if (s0 != s1) {
+ continue;
+ }

- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
- const llama_seq_id s0 = ubatch->seq_id[i0][0];
+ // mask future tokens
+ if (cparams.causal_attn && p0 > p1) {
+ continue;
+ }

- // TODO: reimplement this like in llama_kv_cache_unified
- if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
- if (hparams.use_alibi) {
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
- } else {
- f = 0.0f;
- }
- break;
+ // apply SWA if any
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+ continue;
  }
- }

- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+ }
  }
  }
+ };
+
+ {
+ GGML_ASSERT(self_kq_mask);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+ float * data = (float *) self_kq_mask->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
+
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
+ }
  }
- }

- void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+ GGML_ASSERT(self_kq_mask_swa);
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+
+ float * data = (float *) self_kq_mask_swa->data;
+
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
+
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
+
+ if (debug) {
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+ }
  }
  }

- void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);

- if (self_kq_mask_swa) {
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
- }
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+ }
+
+ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+ res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+ return res;
+ }
+
+ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+ }
+
+ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+ //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+ //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+ res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+ res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+ res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+
+ return res;
  }

  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -303,7 +450,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  const int64_t n_tokens = ubatch->n_tokens;

  GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
- GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+ GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

  float * data = (float *) cross_kq_mask->data;

@@ -324,7 +471,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  }
  }

- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+ for (int i = n_tokens; i < n_tokens; ++i) {
  for (int j = 0; j < n_enc; ++j) {
  data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
  }
@@ -333,15 +480,16 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
  }

  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
- if (self_kq_mask) {
- mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
- }
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

  const int64_t n_rs = mctx->get_recr()->get_n_rs();

- if (s_copy) {
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
- int32_t * data = (int32_t *) s_copy->data;
+ if (inp_rs->s_copy) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;

  // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
  for (uint32_t i = 0; i < n_rs; ++i) {
@@ -350,10 +498,186 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
  }
  }

- void llm_graph_input_one::set_input(const llama_ubatch *) {
- GGML_ASSERT(one && ggml_nelements(one) == 1);
- float f_one = 1.0f;
- ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+ res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+ res &= inp_rs->head == mctx->get_recr()->get_head();
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+ return res;
+ }
+
+ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
+ // set the inputs only for the active samplers in the current ubatch
+ std::unordered_set<llama_seq_id> active_samplers;
+ for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+ if (ubatch->output[i]) {
+ llama_seq_id seq_id = ubatch->seq_id[i][0];
+ active_samplers.insert(seq_id);
+ }
+ }
+
+ for (auto seq_id : active_samplers) {
+ if (samplers.find(seq_id) == samplers.end()) {
+ continue;
+ }
+
+ auto & sampler = samplers[seq_id];
+
+ if (sampler->iface->backend_set_input) {
+ sampler->iface->backend_set_input(sampler);
+ }
+ }
+ }
+
+ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
+ if (samplers.size() != params.samplers.size()) {
+ return false;
+ }
+
+ for (const auto & [seq_id, sampler] : params.samplers) {
+ if (samplers[seq_id] != sampler) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ //
+ // llm_graph_result
+ //
+
+ llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+ reset();
+
+ const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+ debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+ }
+
+ int64_t llm_graph_result::get_max_nodes() const {
+ return max_nodes;
+ }
+
+ void llm_graph_result::reset() {
+ t_tokens = nullptr;
+ t_logits = nullptr;
+ t_embd = nullptr;
+ t_embd_pooled = nullptr;
+ t_sampled.clear();
+ t_sampled_probs.clear();
+ t_sampled_logits.clear();
+ t_candidates.clear();
+
+ params = {};
+
+ inputs.clear();
+
+ buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+ ggml_init_params params = {
+ /*.mem_size =*/ buf_compute_meta.size(),
+ /*.mem_buffer =*/ buf_compute_meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ ctx_compute.reset(ggml_init(params));
+
+ gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+ }
+
+ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+ for (auto & input : inputs) {
+ input->set_input(ubatch);
+ }
+ }
+
+ void llm_graph_result::set_outputs() {
+ if (t_logits != nullptr) {
+ ggml_set_output(t_logits);
+ }
+ if (t_embd != nullptr) {
+ ggml_set_output(t_embd);
+ }
+ if (t_embd_pooled != nullptr) {
+ ggml_set_output(t_embd_pooled);
+ }
+ for (auto & [seq_id, t] : t_sampled) {
+ if (t != nullptr) {
+ ggml_set_output(t);
+ }
+ }
+ for (auto & [seq_id, t] : t_sampled_probs) {
+ if (t != nullptr) {
+ ggml_set_output(t);
+ }
+ }
+ for (auto & [seq_id, t] : t_sampled_logits) {
+ if (t != nullptr) {
+ ggml_set_output(t);
+ }
+ }
+ for (auto & [seq_id, t] : t_candidates) {
+ if (t != nullptr) {
+ ggml_set_output(t);
+ }
+ }
+ }
+
+ bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+ if (!this->params.allow_reuse(params)) {
+ if (debug > 1) {
+ LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+ }
+
+ return false;
+ }
+
+ if (debug > 1) {
+ LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+ }
+
+ bool res = true;
+
+ for (auto & input : inputs) {
+ const bool cur = input->can_reuse(params);
+
+ if (debug > 1) {
+ LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
+ }
+
+ res = res && cur;
+ }
+
+ if (debug > 0) {
+ LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+ }
+
+ return res;
+ }
+
+ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+ inputs.emplace_back(std::move(input));
+ return inputs.back().get();
+ }
+
+ void llm_graph_result::set_params(const llm_graph_params & params) {
+ this->params = params;
  }

  //
@@ -390,15 +714,18 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
  n_ctx_orig (cparams.n_ctx_orig_yarn),
  pooling_type (cparams.pooling_type),
  rope_type (hparams.rope_type),
- ctx0 (params.ctx),
  sched (params.sched),
  backend_cpu (params.backend_cpu),
  cvec (params.cvec),
  loras (params.loras),
  mctx (params.mctx),
  cross (params.cross),
+ samplers (params.samplers),
  cb_func (params.cb),
- res (std::make_unique<llm_graph_result>()) {
+ res (params.res),
+ ctx0 (res->get_ctx()),
+ gf (res->get_gf()) {
+ res->set_params(params);
  }

  void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
@@ -613,6 +940,8 @@ ggml_tensor * llm_graph_context::build_ffn(
  cur = ggml_reglu(ctx0, cur);
  cb(cur, "ffn_reglu", il);
  } break;
+ default:
+ GGML_ABORT("fatal error");
  }

  if (gate && type_gate == LLM_FFN_PAR) {
@@ -622,8 +951,8 @@ ggml_tensor * llm_graph_context::build_ffn(

  if (down) {
  cur = build_lora_mm(down, cur);
- if (arch == LLM_ARCH_GLM4) {
- // GLM4 seems to have numerical issues with half-precision accumulators
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
  }
  }
@@ -658,13 +987,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  bool scale_w,
  float w_scale,
  llama_expert_gating_func_type gating_op,
- int il) const {
+ int il,
+ ggml_tensor * probs_in) const {
+ return build_moe_ffn(
+ cur,
+ gate_inp, /* gate_inp_b */ nullptr,
+ up_exps, /* up_exps_b */ nullptr,
+ gate_exps, /* gate_exps_b */ nullptr,
+ down_exps, /* down_exps_b */ nullptr,
+ exp_probs_b,
+ n_expert,
+ n_expert_used,
+ type_op,
+ norm_w,
+ scale_w,
+ w_scale,
+ gating_op,
+ il,
+ probs_in
+ );
+ }
+
+ ggml_tensor * llm_graph_context::build_moe_ffn(
+ ggml_tensor * cur,
+ ggml_tensor * gate_inp,
+ ggml_tensor * gate_inp_b,
+ ggml_tensor * up_exps,
+ ggml_tensor * up_exps_b,
+ ggml_tensor * gate_exps,
+ ggml_tensor * gate_exps_b,
+ ggml_tensor * down_exps,
+ ggml_tensor * down_exps_b,
+ ggml_tensor * exp_probs_b,
+ int64_t n_expert,
+ int64_t n_expert_used,
+ llm_ffn_op_type type_op,
+ bool norm_w,
+ bool scale_w,
+ float w_scale,
+ llama_expert_gating_func_type gating_op,
+ int il,
+ ggml_tensor * probs_in) const {
  const int64_t n_embd = cur->ne[0];
  const int64_t n_tokens = cur->ne[1];
  const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

- ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
- cb(logits, "ffn_moe_logits", il);
+ ggml_tensor * logits = nullptr;
+
+ if (probs_in == nullptr) {
+ logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+ cb(logits, "ffn_moe_logits", il);
+ } else {
+ logits = probs_in;
+ }
+
+ if (gate_inp_b) {
+ logits = ggml_add(ctx0, logits, gate_inp_b);
+ cb(logits, "ffn_moe_logits_biased", il);
+ }

  ggml_tensor * probs = nullptr;
  switch (gating_op) {
@@ -676,6 +1056,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  {
  probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
  } break;
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+ {
+ probs = logits; // [n_expert, n_tokens]
+ } break;
  default:
  GGML_ABORT("fatal error");
  }
@@ -695,21 +1079,71 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  selection_probs = logits;
  }

+ if (arch == LLM_ARCH_GROVEMOE) {
+ selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+ cb(selection_probs, "ffn_moe_probs_biased", il);
+ }
+
+ // select top n_group_used expert groups
+ // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
+ if (hparams.n_expert_groups > 1 && n_tokens > 0) {
+ const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+
+ // organize experts into n_expert_groups
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
+
+ ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
+
+ // get top n_group_used expert groups
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
+
+ ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
+ cb(expert_groups, "ffn_moe_group_topk", il);
+
+ // mask out the other groups
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+ selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+ cb(selection_probs, "ffn_moe_probs_masked", il);
+ }
+
  // select experts
- ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+ ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
  cb(selected_experts, "ffn_moe_topk", il);

- ggml_tensor * weights = ggml_get_rows(ctx0,
- ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+ if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
+ // TODO: Use scalar div instead when/if implemented
+ ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
+ selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
+ probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
+ } else {
+ probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
+ }
+
+ ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
  cb(weights, "ffn_moe_weights", il);

+
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+ cb(weights, "ffn_moe_weights_softmax", il);
+ }
+
  if (norm_w) {
  weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);

  ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
  cb(weights_sum, "ffn_moe_weights_sum", il);

+ // Avoid division by zero, clamp to smallest number representable by F16
+ weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
+ cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
+
  weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
  cb(weights, "ffn_moe_weights_norm", il);

@@ -720,6 +1154,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  cb(weights, "ffn_moe_weights_scaled", il);
  }

+ //call early so that topk-moe can be used
+ ggml_build_forward_expand(gf, weights);
+
  cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

  if (weight_before_ffn) {
@@ -732,6 +1169,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
  cb(up, "ffn_moe_up", il);

+ if (up_exps_b) {
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+ cb(up, "ffn_moe_up_biased", il);
+ }
+
  ggml_tensor * experts = nullptr;
  if (gate_exps) {
  cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
@@ -740,6 +1182,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  cur = up;
  }

+ if (gate_exps_b) {
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+ cb(cur, "ffn_moe_gate_biased", il);
+ }
+
  switch (type_op) {
  case LLM_FFN_SILU:
  if (gate_exps) {
@@ -757,6 +1204,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  cur = ggml_gelu(ctx0, cur);
  cb(cur, "ffn_moe_gelu", il);
  } break;
+ case LLM_FFN_SWIGLU_OAI_MOE:
+ {
+ // TODO: move to hparams?
+ constexpr float alpha = 1.702f;
+ constexpr float limit = 7.0f;
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+ cb(cur, "ffn_moe_swiglu_oai", il);
+ } break;
+ case LLM_FFN_RELU:
+ if (gate_exps) {
+ cur = ggml_reglu_split(ctx0, cur, up);
+ cb(cur, "ffn_moe_reglu", il);
+ } else {
+ cur = ggml_relu(ctx0, cur);
+ cb(cur, "ffn_moe_relu", il);
+ } break;
+ case LLM_FFN_RELU_SQR:
+ if (gate_exps) {
+ // TODO: add support for gated squared relu
+ GGML_ABORT("fatal error: gated squared relu not implemented");
+ } else {
+ cur = ggml_relu(ctx0, cur);
+ cur = ggml_sqr(ctx0, cur);
+ cb(cur, "ffn_moe_relu_sqr", il);
+ } break;
  default:
  GGML_ABORT("fatal error");
  }
@@ -764,25 +1236,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
  cb(experts, "ffn_moe_down", il);

+ if (down_exps_b) {
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+ cb(experts, "ffn_moe_down_biased", il);
+ }
+
  if (!weight_before_ffn) {
  experts = ggml_mul(ctx0, experts, weights);
  cb(cur, "ffn_moe_weighted", il);
  }

+ ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+ assert(n_expert_used > 0);
+
+ // order the views before the adds
+ for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+ cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+ ggml_build_forward_expand(gf, cur_experts[i]);
+ }
+
  // aggregate experts
- ggml_tensor * moe_out = nullptr;
- for (int i = 0; i < n_expert_used; ++i) {
- ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
- experts->nb[2], i*experts->nb[1]);
+ // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+ // to avoid potentially a large number of add nodes during warmup
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+ ggml_tensor * moe_out = cur_experts[0];

- if (i == 0) {
- moe_out = cur_expert;
- } else {
- moe_out = ggml_add(ctx0, moe_out, cur_expert);
- }
+ for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+ moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
  }

- if (n_expert_used == 1) {
+ if (hparams.n_expert_used == 1) {
  // avoid returning a non-contiguous tensor
  moe_out = ggml_cont(ctx0, moe_out);
  }
@@ -794,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(

  // input embeddings with optional lora
  ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
- const int64_t n_embd = hparams.n_embd;
+ const int64_t n_embd = hparams.n_embd_inp();

  auto inp = std::make_unique<llm_graph_input_embd>();

@@ -841,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {

  res->add_input(std::move(inp));

+ // make sure the produced embeddings are immediately materialized in the ggml graph
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18599
+ ggml_build_forward_expand(gf, cur);
+
  return cur;
  }

@@ -858,7 +1347,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
  }

  ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
- auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+ auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);

  auto & cur = inp->attn_scale;

@@ -906,7 +1395,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
  }

  ggml_tensor * llm_graph_context::build_inp_cls() const {
- auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+ auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);

  auto & cur = inp->cls;

@@ -931,7 +1420,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
  // return cur;
  //}

- const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+ const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
  const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -956,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
  }

  ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

  auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

@@ -987,56 +1476,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
  return pos_bias;
  }

- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
-
- {
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
- const auto n_kv = inp->mctx->get_attn()->get_n_kv();
-
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
- //cb(inp->self_kq_mask, "KQ_mask", -1);
- ggml_set_input(inp->self_kq_mask);
-
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- }
-
- {
- const auto n_rs = mctx_cur->get_recr()->get_n_rs();
-
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
- ggml_set_input(inp->s_copy);
- }
-
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
- }
-
  ggml_tensor * llm_graph_context::build_attn_mha(
- ggml_cgraph * gf,
  ggml_tensor * q,
  ggml_tensor * k,
  ggml_tensor * v,
  ggml_tensor * kq_b,
  ggml_tensor * kq_mask,
+ ggml_tensor * sinks,
  ggml_tensor * v_mla,
- float kq_scale) const {
+ float kq_scale,
+ int il) const {
  const bool v_trans = v->nb[1] > v->nb[2];

+ // split the batch into streams if needed
+ const auto n_stream = k->ne[3];
+
+ q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
+
  q = ggml_permute(ctx0, q, 0, 2, 1, 3);
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);

- const auto n_tokens = q->ne[1];
- const auto n_head = q->ne[2];
- const auto n_kv = k->ne[1];
-
  ggml_tensor * cur;

- // TODO: replace hardcoded padding with ggml-provided padding
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+ if (cparams.flash_attn && kq_b == nullptr) {
  GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

  if (v_trans) {
@@ -1054,8 +1517,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(

  cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+ cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

- ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);

  if (v_mla) {
  #if 0
@@ -1068,14 +1533,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
  // The permutations are noops and only change how the tensor data is interpreted.
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  cur = ggml_mul_mat(ctx0, v_mla, cur);
+ cb(cur, "fattn_mla", il);
  cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
  cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
  #endif
  }

- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
  } else {
  ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ cb(kq, "kq", il);

  // note: this op tends to require high floating point range
  // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1083,42 +1550,54 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1083
1550
 
1084
1551
  if (arch == LLM_ARCH_GROK) {
1085
1552
  // need to do the following:
1086
- // multiply by attn_output_multiplyer of 0.08838834764831845
1553
+ // multiply by attn_output_multiplier
1087
1554
  // and then :
1088
1555
  // kq = 30 * tanh(kq / 30)
1089
1556
  // before the softmax below
1090
1557
 
1091
- kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
1092
- kq = ggml_scale(ctx0, kq, 30);
1558
+ kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
1559
+ cb(kq, "kq_tanh", il);
1560
+ kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1561
+ cb(kq, "kq_scaled", il);
1093
1562
  }
1094
1563
 
1095
1564
  if (hparams.attn_soft_cap) {
1096
1565
  kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
1566
+ cb(kq, "kq_scaled_1", il);
1097
1567
  kq = ggml_tanh (ctx0, kq);
1568
+ cb(kq, "kq_tanh", il);
1098
1569
  kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
1570
+ cb(kq, "kq_scaled_2", il);
1099
1571
  }
1100
1572
 
1101
1573
  if (kq_b) {
1102
1574
  kq = ggml_add(ctx0, kq, kq_b);
1575
+ cb(kq, "kq_plus_kq_b", il);
1103
1576
  }
1104
1577
 
1105
1578
  kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
1579
+ ggml_soft_max_add_sinks(kq, sinks);
1580
+ cb(kq, "kq_soft_max", il);
1106
1581
 
1107
1582
  if (!v_trans) {
1108
1583
  // note: avoid this branch
1109
1584
  v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
1585
+ cb(v, "v_cont", il);
1110
1586
  }
1111
1587
 
1112
1588
  ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
1589
+ cb(kqv, "kqv", il);
1113
1590
 
1114
1591
  // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
1115
1592
  if (v_mla) {
1116
1593
  kqv = ggml_mul_mat(ctx0, v_mla, kqv);
1594
+ cb(kqv, "kqv_mla", il);
1117
1595
  }
1118
1596
 
1119
1597
  cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
1120
1598
 
1121
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
1599
+ // recombine streams
1600
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
1122
1601
 
1123
1602
  if (!cparams.offload_kqv) {
1124
1603
  // all nodes between the KV store and the attention output are run on the CPU
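Both the Grok branch and the attn_soft_cap branch in the hunk above apply the same logit soft-capping pattern, cap * tanh(x / cap), which bounds the attention logits to (-cap, cap) before the softmax while staying nearly linear for small values. A standalone sketch of that formula on a plain float array; the cap of 30.0f is only illustrative (it matches the constant the removed Grok code hard-coded):

    #include <cmath>
    #include <cstdio>

    // Soft-cap: squashes each logit into (-cap, +cap) while staying roughly
    // linear near zero -- the same shape the diff builds with
    // ggml_scale + ggml_tanh + ggml_scale.
    static void soft_cap(float * logits, int n, float cap) {
        for (int i = 0; i < n; ++i) {
            logits[i] = cap * std::tanh(logits[i] / cap);
        }
    }

    int main() {
        float kq[4] = {-100.0f, -5.0f, 5.0f, 100.0f};
        soft_cap(kq, 4, 30.0f); // 30.0f is illustrative, matching the removed constant
        for (int i = 0; i < 4; ++i) {
            std::printf("%.3f\n", kq[i]);
        }
        return 0;
    }

Since tanh is monotonic, capping keeps the relative ordering of the logits; it only prevents extreme values from dominating the softmax.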
@@ -1135,24 +1614,33 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1135
1614
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1136
1615
 
1137
1616
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1138
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1139
- //cb(inp_kq_mask, "KQ_mask", -1);
1140
- ggml_set_input(inp->kq_mask);
1617
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1618
+ ggml_set_input(inp->self_kq_mask);
1619
+
1620
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1141
1621
 
1142
- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
1622
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1623
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1624
+ ggml_set_input(inp->self_kq_mask_swa);
1625
+
1626
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1627
+ } else {
1628
+ inp->self_kq_mask_swa = nullptr;
1629
+ inp->self_kq_mask_swa_cnv = nullptr;
1630
+ }
1143
1631
 
1144
1632
  return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1145
1633
  }
1146
1634
 
1147
1635
  ggml_tensor * llm_graph_context::build_attn(
1148
1636
  llm_graph_input_attn_no_cache * inp,
1149
- ggml_cgraph * gf,
1150
1637
  ggml_tensor * wo,
1151
1638
  ggml_tensor * wo_b,
1152
1639
  ggml_tensor * q_cur,
1153
1640
  ggml_tensor * k_cur,
1154
1641
  ggml_tensor * v_cur,
1155
1642
  ggml_tensor * kq_b,
1643
+ ggml_tensor * sinks,
1156
1644
  ggml_tensor * v_mla,
1157
1645
  float kq_scale,
1158
1646
  int il) const {
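As the note in the hunk above says, without a KV cache the mask is simply n_tokens x n_tokens, plus a second mask when the model uses sliding-window attention. A minimal CPU sketch of what such a mask encodes (0.0f where a token may attend, -INFINITY where it may not), assuming plain causal ordering and a hypothetical window size; this illustrates the masking rule, not the llama.cpp mask-filling code:

    #include <cmath>
    #include <vector>

    // Fill an n_tokens x n_tokens mask: token i may attend to token j when
    // j <= i (causal) and, for the SWA variant, when i - j < window.
    static std::vector<float> build_mask(int n_tokens, int window /* <= 0: no window */) {
        std::vector<float> mask(n_tokens * n_tokens, -INFINITY);
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j <= i; ++j) {
                if (window <= 0 || i - j < window) {
                    mask[i * n_tokens + j] = 0.0f;
                }
            }
        }
        return mask;
    }

    int main() {
        const auto full = build_mask(16, 0); // plain causal mask
        const auto swa  = build_mask(16, 4); // sliding-window mask
        (void) full; (void) swa;
        return 0;
    }

With window = 4, for example, row 10 keeps only columns 7 through 10 open.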
@@ -1164,13 +1652,20 @@ ggml_tensor * llm_graph_context::build_attn(
1164
1652
  ggml_build_forward_expand(gf, k_cur);
1165
1653
  ggml_build_forward_expand(gf, v_cur);
1166
1654
 
1167
- const auto & kq_mask = inp->get_kq_mask();
1655
+ const bool is_swa = hparams.is_swa(il);
1656
+
1657
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1658
+
1659
+ // [TAG_NO_CACHE_PAD]
1660
+ // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
1661
+ // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
1662
+ //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
1168
1663
 
1169
1664
  ggml_tensor * q = q_cur;
1170
1665
  ggml_tensor * k = k_cur;
1171
1666
  ggml_tensor * v = v_cur;
1172
1667
 
1173
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1668
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1174
1669
  cb(cur, "kqv_out", il);
1175
1670
 
1176
1671
  if (wo) {
@@ -1188,50 +1683,70 @@ ggml_tensor * llm_graph_context::build_attn(
1188
1683
  return cur;
1189
1684
  }
1190
1685
 
1191
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1192
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1686
+ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1687
+ ggml_context * ctx0,
1688
+ const llama_ubatch & ubatch,
1689
+ const llama_hparams & hparams,
1690
+ const llama_cparams & cparams,
1691
+ const llama_kv_cache_context * mctx_cur) {
1193
1692
 
1194
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1693
+ auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
1195
1694
 
1196
1695
  {
1197
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1696
+ GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
1697
+
1698
+ const auto n_kv = mctx_cur->get_n_kv();
1699
+ const auto n_tokens = ubatch.n_tokens;
1700
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1198
1701
 
1199
- const auto n_kv = mctx_cur->get_n_kv();
1702
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1703
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1200
1704
 
1201
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1202
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1705
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1203
1706
  ggml_set_input(inp->self_kq_mask);
1204
1707
 
1205
1708
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1206
1709
  }
1207
1710
 
1208
- return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1711
+ return inp;
1712
+ }
1713
+
1714
+ llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
1715
+ const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
1716
+
1717
+ auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1718
+
1719
+ return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
1209
1720
  }
1210
1721
 
1211
1722
  ggml_tensor * llm_graph_context::build_attn(
1212
- llm_graph_input_attn_kv_unified * inp,
1213
- ggml_cgraph * gf,
1723
+ llm_graph_input_attn_kv * inp,
1214
1724
  ggml_tensor * wo,
1215
1725
  ggml_tensor * wo_b,
1216
1726
  ggml_tensor * q_cur,
1217
1727
  ggml_tensor * k_cur,
1218
1728
  ggml_tensor * v_cur,
1219
1729
  ggml_tensor * kq_b,
1730
+ ggml_tensor * sinks,
1220
1731
  ggml_tensor * v_mla,
1221
1732
  float kq_scale,
1222
1733
  int il) const {
1223
1734
  // these nodes are added to the graph together so that they are not reordered
1224
1735
  // by doing so, the number of splits in the graph is reduced
1736
+ // expand k later to enable rope fusion which directly writes into k-v cache
1225
1737
  ggml_build_forward_expand(gf, q_cur);
1226
- ggml_build_forward_expand(gf, k_cur);
1227
1738
  ggml_build_forward_expand(gf, v_cur);
1739
+ ggml_build_forward_expand(gf, k_cur);
1228
1740
 
1229
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1741
+ const auto * mctx_cur = inp->mctx;
1230
1742
 
1231
1743
  // store to KV cache
1232
1744
  {
1233
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1234
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1745
+ const auto & k_idxs = inp->get_k_idxs();
1746
+ const auto & v_idxs = inp->get_v_idxs();
1747
+
1748
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1749
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1235
1750
  }
1236
1751
 
1237
1752
  const auto & kq_mask = inp->get_kq_mask();
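The self_kq_mask created by build_attn_inp_kv_impl above is now 4-D, [n_kv, n_tokens/n_stream, 1, n_stream], with n_stream collapsing to 1 when cparams.kv_unified is set. A small sketch of how the two configurations play out, using hypothetical sizes (n_kv 512, 8 tokens, 2 unique sequences):

    #include <cstdio>

    // Shape of self_kq_mask as built above: [n_kv, n_tokens/n_stream, 1, n_stream].
    static void mask_shape(long n_kv, long n_tokens, long n_seqs_unq, bool kv_unified) {
        const long n_stream = kv_unified ? 1 : n_seqs_unq;
        std::printf("kv_unified=%d -> self_kq_mask [%ld, %ld, 1, %ld]\n",
                    kv_unified ? 1 : 0, n_kv, n_tokens / n_stream, n_stream);
    }

    int main() {
        mask_shape(512, 8, 2, true);   // unified cache:   [512, 8, 1, 1]
        mask_shape(512, 8, 2, false);  // per-seq streams: [512, 4, 1, 2]
        return 0;
    }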
@@ -1240,13 +1755,13 @@ ggml_tensor * llm_graph_context::build_attn(
1240
1755
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1241
1756
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1242
1757
 
1243
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1758
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1244
1759
  cb(cur, "kqv_out", il);
1245
1760
 
1246
1761
  if (wo) {
1247
1762
  cur = build_lora_mm(wo, cur);
1248
- if (arch == LLM_ARCH_GLM4) {
1249
- // GLM4 seems to have numerical issues with half-precision accumulators
1763
+ if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
1764
+ // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
1250
1765
  ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1251
1766
  }
1252
1767
  }
@@ -1259,14 +1774,14 @@ ggml_tensor * llm_graph_context::build_attn(
1259
1774
  }
1260
1775
 
1261
1776
  ggml_tensor * llm_graph_context::build_attn(
1262
- llm_graph_input_attn_kv_unified_iswa * inp,
1263
- ggml_cgraph * gf,
1777
+ llm_graph_input_attn_kv_iswa * inp,
1264
1778
  ggml_tensor * wo,
1265
1779
  ggml_tensor * wo_b,
1266
1780
  ggml_tensor * q_cur,
1267
1781
  ggml_tensor * k_cur,
1268
1782
  ggml_tensor * v_cur,
1269
1783
  ggml_tensor * kq_b,
1784
+ ggml_tensor * sinks,
1270
1785
  ggml_tensor * v_mla,
1271
1786
  float kq_scale,
1272
1787
  int il) const {
@@ -1282,7 +1797,7 @@ ggml_tensor * llm_graph_context::build_attn(
1282
1797
  ggml_build_forward_expand(gf, v_cur);
1283
1798
  }
1284
1799
 
1285
- const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1800
+ const auto * mctx_iswa = inp->mctx;
1286
1801
 
1287
1802
  const bool is_swa = hparams.is_swa(il);
1288
1803
 
@@ -1290,11 +1805,15 @@ ggml_tensor * llm_graph_context::build_attn(
1290
1805
 
1291
1806
  // optionally store to KV cache
1292
1807
  if (k_cur) {
1293
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1808
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
1809
+
1810
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1294
1811
  }
1295
1812
 
1296
1813
  if (v_cur) {
1297
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1814
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
1815
+
1816
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1298
1817
  }
1299
1818
 
1300
1819
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1303,7 +1822,7 @@ ggml_tensor * llm_graph_context::build_attn(
1303
1822
  ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1304
1823
  ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1305
1824
 
1306
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1825
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1307
1826
  cb(cur, "kqv_out", il);
1308
1827
 
1309
1828
  if (wo) {
@@ -1326,7 +1845,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1326
1845
 
1327
1846
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1328
1847
 
1329
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1848
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
1330
1849
  ggml_set_input(inp->cross_kq_mask);
1331
1850
 
1332
1851
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1336,13 +1855,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1336
1855
 
1337
1856
  ggml_tensor * llm_graph_context::build_attn(
1338
1857
  llm_graph_input_attn_cross * inp,
1339
- ggml_cgraph * gf,
1340
1858
  ggml_tensor * wo,
1341
1859
  ggml_tensor * wo_b,
1342
1860
  ggml_tensor * q_cur,
1343
1861
  ggml_tensor * k_cur,
1344
1862
  ggml_tensor * v_cur,
1345
1863
  ggml_tensor * kq_b,
1864
+ ggml_tensor * sinks,
1346
1865
  ggml_tensor * v_mla,
1347
1866
  float kq_scale,
1348
1867
  int il) const {
@@ -1358,7 +1877,7 @@ ggml_tensor * llm_graph_context::build_attn(
1358
1877
  ggml_tensor * k = k_cur;
1359
1878
  ggml_tensor * v = v_cur;
1360
1879
 
1361
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1880
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
1362
1881
  cb(cur, "kqv_out", il);
1363
1882
 
1364
1883
  if (wo) {
@@ -1376,171 +1895,131 @@ ggml_tensor * llm_graph_context::build_attn(
1376
1895
  return cur;
1377
1896
  }
1378
1897
 
1379
- ggml_tensor * llm_graph_context::build_attn(
1380
- llm_graph_input_mem_hybrid * inp,
1381
- ggml_cgraph * gf,
1382
- ggml_tensor * wo,
1383
- ggml_tensor * wo_b,
1384
- ggml_tensor * q_cur,
1385
- ggml_tensor * k_cur,
1386
- ggml_tensor * v_cur,
1387
- ggml_tensor * kq_b,
1388
- ggml_tensor * v_mla,
1389
- float kq_scale,
1390
- int il) const {
1391
- // these nodes are added to the graph together so that they are not reordered
1392
- // by doing so, the number of splits in the graph is reduced
1393
- ggml_build_forward_expand(gf, q_cur);
1394
- ggml_build_forward_expand(gf, k_cur);
1395
- ggml_build_forward_expand(gf, v_cur);
1396
-
1397
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
1398
-
1399
- // store to KV cache
1400
- {
1401
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1402
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1403
- }
1404
-
1405
- const auto & kq_mask = inp->get_kq_mask();
1406
-
1407
- ggml_tensor * q = q_cur;
1408
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1409
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1410
-
1411
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1412
- cb(cur, "kqv_out", il);
1413
-
1414
- if (wo) {
1415
- cur = build_lora_mm(wo, cur);
1416
- if (arch == LLM_ARCH_GLM4) {
1417
- // GLM4 seems to have numerical issues with half-precision accumulators
1418
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1419
- }
1420
- }
1421
-
1422
- if (wo_b) {
1423
- cur = ggml_add(ctx0, cur, wo_b);
1424
- }
1425
-
1426
- return cur;
1427
- }
1898
+ // TODO: maybe separate the inner implementation into a separate function
1899
+ // like with the non-sliding window equivalent
1900
+ // once sliding-window hybrid caches are a thing.
1901
+ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
1902
+ const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
1428
1903
 
1429
- llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1430
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1904
+ auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
1431
1905
 
1432
- auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
1906
+ const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
1433
1907
 
1434
1908
  {
1435
1909
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1436
1910
 
1437
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1438
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1911
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1912
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1913
+
1914
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1439
1915
  ggml_set_input(inp->self_kq_mask);
1916
+ ggml_set_name(inp->self_kq_mask, "self_kq_mask");
1440
1917
 
1441
1918
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1919
+ ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
1442
1920
  }
1443
1921
 
1444
1922
  {
1445
- GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
1923
+ GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
1446
1924
 
1447
1925
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1448
1926
 
1449
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1450
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1927
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1928
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1929
+
1930
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1451
1931
  ggml_set_input(inp->self_kq_mask_swa);
1932
+ ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
1452
1933
 
1453
1934
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1935
+ ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
1454
1936
  }
1455
1937
 
1456
- return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
1938
+ return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
1457
1939
  }
1458
1940
 
1459
1941
  ggml_tensor * llm_graph_context::build_rs(
1460
- ggml_cgraph * gf,
1461
1942
  ggml_tensor * s,
1462
- ggml_tensor * state_copy,
1943
+ ggml_tensor * state_copy_main,
1944
+ ggml_tensor * state_copy_extra,
1463
1945
  int32_t state_size,
1464
1946
  int32_t n_seqs,
1465
- uint32_t n_kv,
1466
- uint32_t kv_head,
1467
- uint32_t kv_size,
1947
+ uint32_t n_rs,
1948
+ uint32_t rs_head,
1949
+ uint32_t rs_size,
1468
1950
  int32_t rs_zero,
1469
- bool avoid_copies) const {
1951
+ const llm_graph_get_rows_fn & get_state_rows) const {
1470
1952
 
1471
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1953
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
1472
1954
 
1473
1955
  // Clear a single state which will then be copied to the other cleared states.
1474
1956
  // Note that this is a no-op when the view is zero-sized.
1475
1957
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1476
1958
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1477
1959
 
1478
- ggml_tensor * output_states;
1960
+ // copy states
1961
+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
1962
+ // {state_size, rs_size} -> {state_size, n_seqs}
1963
+ ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
1964
+ ggml_build_forward_expand(gf, output_states);
1479
1965
 
1480
- if (!avoid_copies) {
1481
- // copy states
1482
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1483
- // {state_size, kv_size} -> {state_size, n_seqs}
1484
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1485
- ggml_build_forward_expand(gf, output_states);
1486
- } else {
1487
- // FIXME: make the gathering operation happen before the copy below
1488
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1489
- output_states = states;
1490
- }
1491
-
1492
- // copy extra states which won't be changed further (between n_seqs and n_kv)
1493
- ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
1966
+ // copy extra states which won't be changed further (between n_seqs and n_rs)
1967
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
1494
1968
  ggml_build_forward_expand(gf,
1495
1969
  ggml_cpy(ctx0,
1496
1970
  states_extra,
1497
- ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
1971
+ ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
1498
1972
 
1499
1973
  return output_states;
1500
1974
  }
1501
1975
 
1502
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1503
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1976
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1977
+ ggml_context * ctx0,
1978
+ const llama_ubatch & ubatch,
1979
+ const llama_memory_recurrent_context * mctx_cur) {
1504
1980
 
1505
1981
  auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1506
1982
 
1507
- const auto n_rs = mctx_cur->get_n_rs();
1983
+ const int64_t n_rs = mctx_cur->get_n_rs();
1984
+ const int64_t n_seqs = ubatch.n_seqs;
1508
1985
 
1509
1986
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1510
1987
  ggml_set_input(inp->s_copy);
1511
1988
 
1512
- return (llm_graph_input_rs *) res->add_input(std::move(inp));
1989
+ inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
1990
+ inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
1991
+
1992
+ inp->head = mctx_cur->get_head();
1993
+ inp->rs_z = mctx_cur->get_rs_z();
1994
+
1995
+ return inp;
1513
1996
  }
1514
1997
 
1515
- ggml_tensor * llm_graph_context::build_rs(
1516
- llm_graph_input_rs * inp,
1517
- ggml_cgraph * gf,
1518
- ggml_tensor * s,
1519
- int32_t state_size,
1520
- int32_t n_seqs,
1521
- bool avoid_copies) const {
1998
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1522
1999
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1523
2000
 
1524
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
2001
+ auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
2002
+
2003
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1525
2004
  }
1526
2005
 
1527
2006
  ggml_tensor * llm_graph_context::build_rs(
1528
- llm_graph_input_mem_hybrid * inp,
1529
- ggml_cgraph * gf,
2007
+ llm_graph_input_rs * inp,
1530
2008
  ggml_tensor * s,
1531
2009
  int32_t state_size,
1532
2010
  int32_t n_seqs,
1533
- bool avoid_copies) const {
1534
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
2011
+ const llm_graph_get_rows_fn & get_state_rows) const {
2012
+ const auto * kv_state = inp->mctx;
1535
2013
 
1536
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
2014
+ return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
2015
+ kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
2016
+ get_state_rows);
1537
2017
  }
1538
2018
 
1539
2019
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1540
2020
  llm_graph_input_rs * inp,
1541
- ggml_cgraph * gf,
1542
2021
  const llama_ubatch & ubatch,
1543
- int il) const {
2022
+ int il) const {
1544
2023
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1545
2024
 
1546
2025
  const auto token_shift_count = hparams.token_shift_count;
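build_rs in the hunk above expresses the recurrent-state shuffle as two row gathers: s_copy_main selects the n_seqs states that will be updated this ubatch, and s_copy_extra copies the remaining states back untouched. A plain-array sketch of the row gather that ggml_get_rows performs here; names and sizes are illustrative, not taken from the library:

    #include <cstring>
    #include <vector>

    // Gather rows of width `state_size` from `states` (row-major, one row per
    // recurrent slot) according to `ids`.
    static std::vector<float> get_rows(const std::vector<float> & states,
                                       const std::vector<int>   & ids,
                                       int state_size) {
        std::vector<float> out(ids.size() * (size_t) state_size);
        for (size_t i = 0; i < ids.size(); ++i) {
            std::memcpy(out.data() + i * state_size,
                        states.data() + (size_t) ids[i] * state_size,
                        state_size * sizeof(float));
        }
        return out;
    }

    int main() {
        // two recurrent slots of width 3; bring slot 1 to the front, keep slot 0
        std::vector<float> states   = {0, 0, 0, 1, 1, 1};
        std::vector<int>   ids      = {1, 0};
        std::vector<float> gathered = get_rows(states, ids, 3); // {1,1,1, 0,0,0}
        (void) gathered;
        return 0;
    }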
@@ -1550,7 +2029,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
1550
2029
  ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
1551
2030
 
1552
2031
  ggml_tensor * token_shift = build_rs(
1553
- inp, gf, token_shift_all,
2032
+ inp, token_shift_all,
1554
2033
  hparams.n_embd_r(), n_seqs);
1555
2034
 
1556
2035
  token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,8 +2057,39 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1578
2057
  );
1579
2058
  }
1580
2059
 
2060
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
2061
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
2062
+
2063
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
2064
+ auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
2065
+
2066
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
2067
+
2068
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
2069
+ }
2070
+
2071
+ void llm_graph_context::build_dense_out(
2072
+ ggml_tensor * dense_2,
2073
+ ggml_tensor * dense_3) const {
2074
+ if (!cparams.embeddings || !(dense_2 || dense_3)) {
2075
+ return;
2076
+ }
2077
+ ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2078
+ GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2079
+
2080
+ if (dense_2) {
2081
+ cur = ggml_mul_mat(ctx0, dense_2, cur);
2082
+ }
2083
+ if (dense_3) {
2084
+ cur = ggml_mul_mat(ctx0, dense_3, cur);
2085
+ }
2086
+ cb(cur, "result_embd_pooled", -1);
2087
+ res->t_embd_pooled = cur;
2088
+ ggml_build_forward_expand(gf, cur);
2089
+ }
2090
+
2091
+
1581
2092
  void llm_graph_context::build_pooling(
1582
- ggml_cgraph * gf,
1583
2093
  ggml_tensor * cls,
1584
2094
  ggml_tensor * cls_b,
1585
2095
  ggml_tensor * cls_out,
@@ -1623,34 +2133,32 @@ void llm_graph_context::build_pooling(
1623
2133
  case LLAMA_POOLING_TYPE_RANK:
1624
2134
  {
1625
2135
  ggml_tensor * inp_cls = build_inp_cls();
1626
- inp = ggml_get_rows(ctx0, inp, inp_cls);
2136
+ cur = ggml_get_rows(ctx0, inp, inp_cls);
1627
2137
 
2138
+ // classification head
2139
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1628
2140
  if (cls) {
1629
- // classification head
1630
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1631
- cur = ggml_mul_mat(ctx0, cls, inp);
2141
+ cur = ggml_mul_mat(ctx0, cls, cur);
1632
2142
  if (cls_b) {
1633
2143
  cur = ggml_add(ctx0, cur, cls_b);
1634
2144
  }
1635
2145
  cur = ggml_tanh(ctx0, cur);
2146
+ }
1636
2147
 
1637
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1638
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1639
- if (cls_out) {
1640
- cur = ggml_mul_mat(ctx0, cls_out, cur);
1641
- if (cls_out_b) {
1642
- cur = ggml_add(ctx0, cur, cls_out_b);
1643
- }
1644
- }
1645
- } else if (cls_out) {
1646
- // Single layer classification head (direct projection)
1647
- // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
1648
- cur = ggml_mul_mat(ctx0, cls_out, inp);
2148
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
2149
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
2150
+ // Single layer classification head (direct projection)
2151
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
2152
+ if (cls_out) {
2153
+ cur = ggml_mul_mat(ctx0, cls_out, cur);
1649
2154
  if (cls_out_b) {
1650
2155
  cur = ggml_add(ctx0, cur, cls_out_b);
1651
2156
  }
1652
- } else {
1653
- GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
2157
+ }
2158
+
2159
+ // softmax for qwen3 reranker
2160
+ if (arch == LLM_ARCH_QWEN3) {
2161
+ cur = ggml_soft_max(ctx0, cur);
1654
2162
  }
1655
2163
  } break;
1656
2164
  default:
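The reshuffled RANK branch above applies up to three optional steps to each pooled row: a dense layer with tanh (cls / cls_b), an output projection (cls_out / cls_out_b), and a final softmax for the Qwen3 reranker. A scalar sketch of that pipeline for a single pooled vector, assuming all three steps are present; the weights are hypothetical and the code only mirrors the order of operations shown above:

    #include <cmath>
    #include <vector>

    using vec = std::vector<float>;
    using mat = std::vector<vec>; // row-major: mat[i] is output row i

    static vec matvec(const mat & w, const vec & x) {
        vec y(w.size(), 0.0f);
        for (size_t i = 0; i < w.size(); ++i)
            for (size_t j = 0; j < x.size(); ++j)
                y[i] += w[i][j] * x[j];
        return y;
    }

    // cls (+ bias, tanh), then cls_out (+ bias), then softmax -- the order used
    // by the RANK branch when every optional piece is present.
    static vec rank_head(const vec & pooled,
                         const mat & cls,     const vec & cls_b,
                         const mat & cls_out, const vec & cls_out_b) {
        vec h = matvec(cls, pooled);
        for (size_t i = 0; i < h.size(); ++i) h[i] = std::tanh(h[i] + cls_b[i]);

        vec o = matvec(cls_out, h);
        float sum = 0.0f;
        for (size_t i = 0; i < o.size(); ++i) { o[i] = std::exp(o[i] + cls_out_b[i]); sum += o[i]; }
        for (size_t i = 0; i < o.size(); ++i) o[i] /= sum;
        return o;
    }

    int main() {
        const mat cls       = {{0.5f, -0.25f}, {0.1f, 0.3f}};
        const vec cls_b     = {0.0f, 0.1f};
        const mat cls_out   = {{1.0f, -1.0f}, {-1.0f, 1.0f}};
        const vec cls_out_b = {0.0f, 0.0f};
        const vec score = rank_head({0.2f, 0.7f}, cls, cls_b, cls_out, cls_out_b);
        (void) score; // relevance scores for one pooled row
        return 0;
    }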
@@ -1665,6 +2173,87 @@ void llm_graph_context::build_pooling(
1665
2173
  ggml_build_forward_expand(gf, cur);
1666
2174
  }
1667
2175
 
2176
+ void llm_graph_context::build_sampling() const {
2177
+ if (samplers.empty() || !res->t_logits) {
2178
+ return;
2179
+ }
2180
+
2181
+ auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2182
+ res->add_input(std::move(inp_sampling));
2183
+
2184
+ std::map<llama_seq_id, int32_t> seq_to_logit_row;
2185
+ int32_t logit_row_idx = 0;
2186
+
2187
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2188
+ if (ubatch.output[i]) {
2189
+ llama_seq_id seq_id = ubatch.seq_id[i][0];
2190
+ seq_to_logit_row[seq_id] = logit_row_idx;
2191
+ logit_row_idx++;
2192
+ }
2193
+ }
2194
+
2195
+ // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
2196
+ GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2197
+
2198
+ // add a dummy row of logits
2199
+ // this trick makes the graph static, regardless of which samplers are activated
2200
+ // this is important in order to minimize graph reallocations
2201
+ // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
2202
+ ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2203
+
2204
+ for (const auto & [seq_id, sampler] : samplers) {
2205
+ const auto it = seq_to_logit_row.find(seq_id);
2206
+
2207
+ // inactive samplers always work on the first row
2208
+ const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
2209
+
2210
+ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2211
+ ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2212
+
2213
+ struct llama_sampler_data data = {
2214
+ /*.logits =*/ logits_seq,
2215
+ /*.probs =*/ nullptr,
2216
+ /*.sampled =*/ nullptr,
2217
+ /*.candidates =*/ nullptr,
2218
+ };
2219
+
2220
+ assert(sampler->iface->backend_apply);
2221
+ sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2222
+
2223
+ if (data.sampled != nullptr) {
2224
+ res->t_sampled[seq_id] = data.sampled;
2225
+ ggml_build_forward_expand(gf, data.sampled);
2226
+ }
2227
+
2228
+ if (data.probs != nullptr) {
2229
+ res->t_sampled_probs[seq_id] = data.probs;
2230
+ ggml_build_forward_expand(gf, data.probs);
2231
+ }
2232
+
2233
+ if (data.logits != nullptr) {
2234
+ res->t_sampled_logits[seq_id] = data.logits;
2235
+ ggml_build_forward_expand(gf, data.logits);
2236
+ }
2237
+
2238
+ if (data.candidates != nullptr) {
2239
+ res->t_candidates[seq_id] = data.candidates;
2240
+ ggml_build_forward_expand(gf, data.candidates);
2241
+ }
2242
+ }
2243
+
2244
+ // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2245
+ /*
2246
+ for (const auto & [seq_id, sampler] : samplers) {
2247
+ if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2248
+ ggml_tensor * selected_token = it->second;
2249
+ if (selected_token != nullptr) {
2250
+ llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2251
+ }
2252
+ }
2253
+ }
2254
+ */
2255
+ }
2256
+
1668
2257
  int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1669
2258
  // TODO move to hparams if a T5 variant appears that uses a different value
1670
2259
  const int64_t max_distance = 128;
@@ -1680,7 +2269,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
1680
2269
 
1681
2270
  if (bidirectional) {
1682
2271
  relative_bucket += (relative_position > 0) * n_buckets;
1683
- relative_position = abs(relative_position);
2272
+ relative_position = std::abs(relative_position);
1684
2273
  } else {
1685
2274
  relative_position = -std::min<int32_t>(relative_position, 0);
1686
2275
  }
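
llama_relative_position_bucket follows the T5 bucketing scheme: when bidirectional, half the buckets go to each sign of the offset; small offsets each get an exact bucket, and larger offsets share logarithmically spaced buckets up to max_distance (128 in the hunk above). A self-contained sketch written from the T5 reference rather than copied from this file, so anything beyond what the hunk shows is an assumption:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // T5-style relative position bucket (sketch). `rel` is key_pos - query_pos.
    static int32_t rel_pos_bucket(int32_t rel, int32_t n_buckets, bool bidirectional,
                                  int32_t max_distance = 128) {
        int32_t bucket = 0;
        if (bidirectional) {
            n_buckets /= 2;                     // one half of the buckets per direction
            bucket   += (rel > 0) * n_buckets;
            if (rel < 0) rel = -rel;
        } else {
            rel = -std::min(rel, 0);            // causal: only look backwards
        }
        const int32_t max_exact = n_buckets / 2;
        if (rel < max_exact) {
            return bucket + rel;                // small offsets get one bucket each
        }
        // larger offsets share logarithmically spaced buckets up to max_distance
        const int32_t large = max_exact + (int32_t)(
            std::log((float) rel / max_exact) / std::log((float) max_distance / max_exact)
            * (n_buckets - max_exact));
        return bucket + std::min(large, n_buckets - 1);
    }

    int main() {
        std::printf("%d %d %d\n",
                    rel_pos_bucket(-1,   32, true),   // small offset: exact bucket
                    rel_pos_bucket(-200, 32, true),   // large offset: clamped log bucket
                    rel_pos_bucket(5,    32, true));  // positive offsets use the upper half
        return 0;
    }

Keeping exact buckets for nearby offsets and logarithmic ones for distant offsets is what lets a small bucket count (32 in T5) cover arbitrarily long contexts.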