whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -12,6 +12,7 @@
12
12
  #include <cassert>
13
13
  #include <cmath>
14
14
  #include <cstring>
15
+ #include <unordered_set>
15
16
 
16
17
  void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
17
18
  if (ubatch->token) {
@@ -32,7 +33,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
32
33
  bool res = true;
33
34
 
34
35
  res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
35
- res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
36
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
36
37
 
37
38
  return res;
38
39
  }
@@ -62,7 +63,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
62
63
  bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
63
64
  bool res = true;
64
65
 
65
- res &= pos->ne[0] == params.ubatch.n_tokens;
66
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
66
67
 
67
68
  return res;
68
69
  }
@@ -71,11 +72,14 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
71
72
  if (ubatch->pos && attn_scale) {
72
73
  const int64_t n_tokens = ubatch->n_tokens;
73
74
 
75
+ GGML_ASSERT(f_attn_temp_scale != 0.0f);
76
+ GGML_ASSERT(n_attn_temp_floor_scale != 0);
77
+
74
78
  std::vector<float> attn_scale_data(n_tokens, 0.0f);
75
79
  for (int i = 0; i < n_tokens; ++i) {
76
80
  const float pos = ubatch->pos[i];
77
81
  attn_scale_data[i] = std::log(
78
- std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
82
+ std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
79
83
  ) * f_attn_temp_scale + 1.0;
80
84
  }
81
85
 
@@ -251,6 +255,24 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
251
255
  }
252
256
  }
253
257
 
258
+ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
259
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
260
+
261
+ this->mctx = mctx;
262
+
263
+ bool res = true;
264
+
265
+ res &= s_copy->ne[0] == mctx->get_n_rs();
266
+
267
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
268
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
269
+
270
+ res &= head == mctx->get_head();
271
+ res &= rs_z == mctx->get_rs_z();
272
+
273
+ return res;
274
+ }
275
+
254
276
  void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
255
277
  GGML_UNUSED(ubatch);
256
278
 
@@ -261,12 +283,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
261
283
  }
262
284
  }
263
285
 
264
- static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
286
+ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
265
287
  LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
266
- const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
267
- (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
268
- (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
269
- (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
288
+ const char * swa_type_str = "unknown";
289
+
290
+ switch (swa_type) {
291
+ case LLAMA_SWA_TYPE_NONE: swa_type_str = "LLAMA_SWA_TYPE_NONE"; break;
292
+ case LLAMA_SWA_TYPE_STANDARD: swa_type_str = "LLAMA_SWA_TYPE_STANDARD"; break;
293
+ case LLAMA_SWA_TYPE_CHUNKED: swa_type_str = "LLAMA_SWA_TYPE_CHUNKED"; break;
294
+ case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
295
+ };
296
+
270
297
  LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
271
298
  LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
272
299
  LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -295,50 +322,67 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
295
322
  const int64_t n_kv = ubatch->n_tokens;
296
323
  const int64_t n_tokens = ubatch->n_tokens;
297
324
 
298
- GGML_ASSERT(kq_mask);
299
- GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
300
-
301
- float * data = (float *) kq_mask->data;
302
-
303
- // [TAG_NO_CACHE_ISWA]
304
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
305
-
306
- for (int h = 0; h < 1; ++h) {
307
- for (int i1 = 0; i1 < n_tokens; ++i1) {
308
- const llama_seq_id s1 = ubatch->seq_id[i1][0];
325
+ const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
326
+ for (int h = 0; h < 1; ++h) {
327
+ for (int i1 = 0; i1 < n_tokens; ++i1) {
328
+ const llama_seq_id s1 = ubatch->seq_id[i1][0];
329
+ const llama_pos p1 = ubatch->pos[i1];
309
330
 
310
- for (int i0 = 0; i0 < n_tokens; ++i0) {
311
- float f = -INFINITY;
331
+ const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
312
332
 
313
- for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
333
+ for (int i0 = 0; i0 < n_tokens; ++i0) {
314
334
  const llama_seq_id s0 = ubatch->seq_id[i0][0];
335
+ const llama_pos p0 = ubatch->pos[i0];
315
336
 
337
+ // mask different sequences
316
338
  if (s0 != s1) {
317
- continue; // skip different sequences
339
+ continue;
318
340
  }
319
341
 
320
- if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
321
- continue; // skip future tokens for causal attention
342
+ // mask future tokens
343
+ if (cparams.causal_attn && p0 > p1) {
344
+ continue;
322
345
  }
323
346
 
324
- // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
325
- //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
326
- // continue; // skip masked tokens for SWA
327
- //}
328
-
329
- // TODO: reimplement this like in llama_kv_cache_unified
330
- if (hparams.use_alibi) {
331
- f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
332
- } else {
333
- f = 0.0f;
347
+ // apply SWA if any
348
+ if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
349
+ continue;
334
350
  }
351
+
352
+ data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
335
353
  }
336
- data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
337
354
  }
338
355
  }
356
+ };
357
+
358
+ {
359
+ GGML_ASSERT(self_kq_mask);
360
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
361
+
362
+ float * data = (float *) self_kq_mask->data;
363
+
364
+ std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY);
365
+
366
+ fill_mask(data, 0, LLAMA_SWA_TYPE_NONE);
367
+
368
+ if (debug) {
369
+ print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE);
370
+ }
339
371
  }
340
- if (debug) {
341
- print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
372
+
373
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
374
+ GGML_ASSERT(self_kq_mask_swa);
375
+ GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
376
+
377
+ float * data = (float *) self_kq_mask_swa->data;
378
+
379
+ std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY);
380
+
381
+ fill_mask(data, hparams.n_swa, hparams.swa_type);
382
+
383
+ if (debug) {
384
+ print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
385
+ }
342
386
  }
343
387
  }
344
388
 
@@ -360,7 +404,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
360
404
  //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
361
405
 
362
406
  res &= self_kq_mask->ne[0] == mctx->get_n_kv();
363
- res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
407
+ res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
364
408
 
365
409
  return res;
366
410
  }
@@ -391,10 +435,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
391
435
  //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
392
436
 
393
437
  res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
394
- res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
438
+ res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
395
439
 
396
440
  res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
397
- res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
441
+ res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
398
442
 
399
443
  return res;
400
444
  }
@@ -427,7 +471,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
427
471
  }
428
472
  }
429
473
 
430
- for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
474
+ for (int i = n_tokens; i < n_tokens; ++i) {
431
475
  for (int j = 0; j < n_enc; ++j) {
432
476
  data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
433
477
  }
@@ -436,8 +480,83 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
436
480
  }
437
481
 
438
482
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
439
- inp_attn->set_input(ubatch);
440
- inp_rs->set_input(ubatch);
483
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
484
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
485
+
486
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
487
+
488
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
489
+
490
+ if (inp_rs->s_copy) {
491
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
492
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
493
+
494
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
495
+ for (uint32_t i = 0; i < n_rs; ++i) {
496
+ data[i] = mctx->get_recr()->s_copy(i);
497
+ }
498
+ }
499
+ }
500
+
501
+ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
502
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
503
+
504
+ this->mctx = mctx;
505
+
506
+ bool res = true;
507
+
508
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
509
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
510
+
511
+ res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
512
+ res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
513
+
514
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
515
+
516
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
517
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
518
+
519
+ res &= inp_rs->head == mctx->get_recr()->get_head();
520
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
521
+
522
+ return res;
523
+ }
524
+
525
+ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
526
+ // set the inputs only for the active samplers in the current ubatch
527
+ std::unordered_set<llama_seq_id> active_samplers;
528
+ for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
529
+ if (ubatch->output[i]) {
530
+ llama_seq_id seq_id = ubatch->seq_id[i][0];
531
+ active_samplers.insert(seq_id);
532
+ }
533
+ }
534
+
535
+ for (auto seq_id : active_samplers) {
536
+ if (samplers.find(seq_id) == samplers.end()) {
537
+ continue;
538
+ }
539
+
540
+ auto & sampler = samplers[seq_id];
541
+
542
+ if (sampler->iface->backend_set_input) {
543
+ sampler->iface->backend_set_input(sampler);
544
+ }
545
+ }
546
+ }
547
+
548
+ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
549
+ if (samplers.size() != params.samplers.size()) {
550
+ return false;
551
+ }
552
+
553
+ for (const auto & [seq_id, sampler] : params.samplers) {
554
+ if (samplers[seq_id] != sampler) {
555
+ return false;
556
+ }
557
+ }
558
+
559
+ return true;
441
560
  }
442
561
 
443
562
  //
@@ -460,6 +579,10 @@ void llm_graph_result::reset() {
460
579
  t_logits = nullptr;
461
580
  t_embd = nullptr;
462
581
  t_embd_pooled = nullptr;
582
+ t_sampled.clear();
583
+ t_sampled_probs.clear();
584
+ t_sampled_logits.clear();
585
+ t_candidates.clear();
463
586
 
464
587
  params = {};
465
588
 
@@ -484,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
484
607
  }
485
608
  }
486
609
 
610
+ void llm_graph_result::set_outputs() {
611
+ if (t_logits != nullptr) {
612
+ ggml_set_output(t_logits);
613
+ }
614
+ if (t_embd != nullptr) {
615
+ ggml_set_output(t_embd);
616
+ }
617
+ if (t_embd_pooled != nullptr) {
618
+ ggml_set_output(t_embd_pooled);
619
+ }
620
+ for (auto & [seq_id, t] : t_sampled) {
621
+ if (t != nullptr) {
622
+ ggml_set_output(t);
623
+ }
624
+ }
625
+ for (auto & [seq_id, t] : t_sampled_probs) {
626
+ if (t != nullptr) {
627
+ ggml_set_output(t);
628
+ }
629
+ }
630
+ for (auto & [seq_id, t] : t_sampled_logits) {
631
+ if (t != nullptr) {
632
+ ggml_set_output(t);
633
+ }
634
+ }
635
+ for (auto & [seq_id, t] : t_candidates) {
636
+ if (t != nullptr) {
637
+ ggml_set_output(t);
638
+ }
639
+ }
640
+ }
641
+
487
642
  bool llm_graph_result::can_reuse(const llm_graph_params & params) {
488
643
  if (!this->params.allow_reuse(params)) {
489
644
  if (debug > 1) {
@@ -565,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
565
720
  loras (params.loras),
566
721
  mctx (params.mctx),
567
722
  cross (params.cross),
723
+ samplers (params.samplers),
568
724
  cb_func (params.cb),
569
725
  res (params.res),
570
726
  ctx0 (res->get_ctx()),
@@ -928,8 +1084,33 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
928
1084
  cb(selection_probs, "ffn_moe_probs_biased", il);
929
1085
  }
930
1086
 
1087
+ // select top n_group_used expert groups
1088
+ // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
1089
+ if (hparams.n_expert_groups > 1 && n_tokens > 0) {
1090
+ const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
1091
+
1092
+ // organize experts into n_expert_groups
1093
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
1094
+
1095
+ ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
1096
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
1097
+
1098
+ // get top n_group_used expert groups
1099
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
1100
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
1101
+
1102
+ ggml_tensor * expert_groups = ggml_argsort_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
1103
+ cb(expert_groups, "ffn_moe_group_topk", il);
1104
+
1105
+ // mask out the other groups
1106
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
1107
+ selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
1108
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
1109
+ cb(selection_probs, "ffn_moe_probs_masked", il);
1110
+ }
1111
+
931
1112
  // select experts
932
- ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
1113
+ ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
933
1114
  cb(selected_experts->src[0], "ffn_moe_argsort", il);
934
1115
  cb(selected_experts, "ffn_moe_topk", il);
935
1116
 
@@ -959,6 +1140,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
959
1140
  ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
960
1141
  cb(weights_sum, "ffn_moe_weights_sum", il);
961
1142
 
1143
+ // Avoid division by zero, clamp to smallest number representable by F16
1144
+ weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
1145
+ cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
1146
+
962
1147
  weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
963
1148
  cb(weights, "ffn_moe_weights_norm", il);
964
1149
 
@@ -1035,6 +1220,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1035
1220
  cur = ggml_relu(ctx0, cur);
1036
1221
  cb(cur, "ffn_moe_relu", il);
1037
1222
  } break;
1223
+ case LLM_FFN_RELU_SQR:
1224
+ if (gate_exps) {
1225
+ // TODO: add support for gated squared relu
1226
+ GGML_ABORT("fatal error: gated squared relu not implemented");
1227
+ } else {
1228
+ cur = ggml_relu(ctx0, cur);
1229
+ cur = ggml_sqr(ctx0, cur);
1230
+ cb(cur, "ffn_moe_relu_sqr", il);
1231
+ } break;
1038
1232
  default:
1039
1233
  GGML_ABORT("fatal error");
1040
1234
  }
@@ -1085,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
1085
1279
 
1086
1280
  // input embeddings with optional lora
1087
1281
  ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1088
- const int64_t n_embd = hparams.n_embd;
1282
+ const int64_t n_embd = hparams.n_embd_inp();
1089
1283
 
1090
1284
  auto inp = std::make_unique<llm_graph_input_embd>();
1091
1285
 
@@ -1132,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
1132
1326
 
1133
1327
  res->add_input(std::move(inp));
1134
1328
 
1329
+ // make sure the produced embeddings are immediately materialized in the ggml graph
1330
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18599
1331
+ ggml_build_forward_expand(gf, cur);
1332
+
1135
1333
  return cur;
1136
1334
  }
1137
1335
 
@@ -1149,7 +1347,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
1149
1347
  }
1150
1348
 
1151
1349
  ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
1152
- auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
1350
+ auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
1153
1351
 
1154
1352
  auto & cur = inp->attn_scale;
1155
1353
 
@@ -1222,7 +1420,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
1222
1420
  // return cur;
1223
1421
  //}
1224
1422
 
1225
- const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
1423
+ const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
1226
1424
  const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1227
1425
 
1228
1426
  cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
@@ -1299,12 +1497,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
1299
1497
  k = ggml_permute(ctx0, k, 0, 2, 1, 3);
1300
1498
  v = ggml_permute(ctx0, v, 0, 2, 1, 3);
1301
1499
 
1302
- const auto n_kv = k->ne[1];
1303
-
1304
1500
  ggml_tensor * cur;
1305
1501
 
1306
- // TODO: replace hardcoded padding with ggml-provided padding
1307
- if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
1502
+ if (cparams.flash_attn && kq_b == nullptr) {
1308
1503
  GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
1309
1504
 
1310
1505
  if (v_trans) {
@@ -1419,10 +1614,20 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1419
1614
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1420
1615
 
1421
1616
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1422
- inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1423
- ggml_set_input(inp->kq_mask);
1617
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1618
+ ggml_set_input(inp->self_kq_mask);
1619
+
1620
+ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1621
+
1622
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
1623
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
1624
+ ggml_set_input(inp->self_kq_mask_swa);
1424
1625
 
1425
- inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
1626
+ inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1627
+ } else {
1628
+ inp->self_kq_mask_swa = nullptr;
1629
+ inp->self_kq_mask_swa_cnv = nullptr;
1630
+ }
1426
1631
 
1427
1632
  return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
1428
1633
  }
@@ -1447,7 +1652,9 @@ ggml_tensor * llm_graph_context::build_attn(
1447
1652
  ggml_build_forward_expand(gf, k_cur);
1448
1653
  ggml_build_forward_expand(gf, v_cur);
1449
1654
 
1450
- const auto & kq_mask = inp->get_kq_mask();
1655
+ const bool is_swa = hparams.is_swa(il);
1656
+
1657
+ const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
1451
1658
 
1452
1659
  // [TAG_NO_CACHE_PAD]
1453
1660
  // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
@@ -1495,7 +1702,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
1495
1702
  inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1496
1703
  inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1497
1704
 
1498
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1705
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1499
1706
  ggml_set_input(inp->self_kq_mask);
1500
1707
 
1501
1708
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1526,9 +1733,10 @@ ggml_tensor * llm_graph_context::build_attn(
1526
1733
  int il) const {
1527
1734
  // these nodes are added to the graph together so that they are not reordered
1528
1735
  // by doing so, the number of splits in the graph is reduced
1736
+ // expand k later to enable rope fusion which directly writes into k-v cache
1529
1737
  ggml_build_forward_expand(gf, q_cur);
1530
- ggml_build_forward_expand(gf, k_cur);
1531
1738
  ggml_build_forward_expand(gf, v_cur);
1739
+ ggml_build_forward_expand(gf, k_cur);
1532
1740
 
1533
1741
  const auto * mctx_cur = inp->mctx;
1534
1742
 
@@ -1637,7 +1845,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1637
1845
 
1638
1846
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1639
1847
 
1640
- inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1848
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1);
1641
1849
  ggml_set_input(inp->cross_kq_mask);
1642
1850
 
1643
1851
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1703,10 +1911,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
1703
1911
  inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1704
1912
  inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1705
1913
 
1706
- inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1914
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1707
1915
  ggml_set_input(inp->self_kq_mask);
1916
+ ggml_set_name(inp->self_kq_mask, "self_kq_mask");
1708
1917
 
1709
1918
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1919
+ ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
1710
1920
  }
1711
1921
 
1712
1922
  {
@@ -1717,10 +1927,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
1717
1927
  inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1718
1928
  inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1719
1929
 
1720
- inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
1930
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
1721
1931
  ggml_set_input(inp->self_kq_mask_swa);
1932
+ ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
1722
1933
 
1723
1934
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
1935
+ ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
1724
1936
  }
1725
1937
 
1726
1938
  return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -1777,6 +1989,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1777
1989
  inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
1778
1990
  inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
1779
1991
 
1992
+ inp->head = mctx_cur->get_head();
1993
+ inp->rs_z = mctx_cur->get_rs_z();
1994
+
1780
1995
  return inp;
1781
1996
  }
1782
1997
 
@@ -1845,14 +2060,35 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1845
2060
  llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1846
2061
  const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1847
2062
 
1848
- auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
2063
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
1849
2064
  auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1850
2065
 
1851
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
2066
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
1852
2067
 
1853
2068
  return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1854
2069
  }
1855
2070
 
2071
+ void llm_graph_context::build_dense_out(
2072
+ ggml_tensor * dense_2,
2073
+ ggml_tensor * dense_3) const {
2074
+ if (!cparams.embeddings || !(dense_2 || dense_3)) {
2075
+ return;
2076
+ }
2077
+ ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
2078
+ GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
2079
+
2080
+ if (dense_2) {
2081
+ cur = ggml_mul_mat(ctx0, dense_2, cur);
2082
+ }
2083
+ if (dense_3) {
2084
+ cur = ggml_mul_mat(ctx0, dense_3, cur);
2085
+ }
2086
+ cb(cur, "result_embd_pooled", -1);
2087
+ res->t_embd_pooled = cur;
2088
+ ggml_build_forward_expand(gf, cur);
2089
+ }
2090
+
2091
+
1856
2092
  void llm_graph_context::build_pooling(
1857
2093
  ggml_tensor * cls,
1858
2094
  ggml_tensor * cls_b,
@@ -1937,6 +2173,87 @@ void llm_graph_context::build_pooling(
1937
2173
  ggml_build_forward_expand(gf, cur);
1938
2174
  }
1939
2175
 
2176
+ void llm_graph_context::build_sampling() const {
2177
+ if (samplers.empty() || !res->t_logits) {
2178
+ return;
2179
+ }
2180
+
2181
+ auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
2182
+ res->add_input(std::move(inp_sampling));
2183
+
2184
+ std::map<llama_seq_id, int32_t> seq_to_logit_row;
2185
+ int32_t logit_row_idx = 0;
2186
+
2187
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
2188
+ if (ubatch.output[i]) {
2189
+ llama_seq_id seq_id = ubatch.seq_id[i][0];
2190
+ seq_to_logit_row[seq_id] = logit_row_idx;
2191
+ logit_row_idx++;
2192
+ }
2193
+ }
2194
+
2195
+ // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
2196
+ GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
2197
+
2198
+ // add a dummy row of logits
2199
+ // this trick makes the graph static, regardless of which samplers are activated
2200
+ // this is important in order to minimize graph reallocations
2201
+ // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
2202
+ ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
2203
+
2204
+ for (const auto & [seq_id, sampler] : samplers) {
2205
+ const auto it = seq_to_logit_row.find(seq_id);
2206
+
2207
+ // inactive samplers always work on the first row
2208
+ const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
2209
+
2210
+ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
2211
+ ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
2212
+
2213
+ struct llama_sampler_data data = {
2214
+ /*.logits =*/ logits_seq,
2215
+ /*.probs =*/ nullptr,
2216
+ /*.sampled =*/ nullptr,
2217
+ /*.candidates =*/ nullptr,
2218
+ };
2219
+
2220
+ assert(sampler->iface->backend_apply);
2221
+ sampler->iface->backend_apply(sampler, ctx0, gf, &data);
2222
+
2223
+ if (data.sampled != nullptr) {
2224
+ res->t_sampled[seq_id] = data.sampled;
2225
+ ggml_build_forward_expand(gf, data.sampled);
2226
+ }
2227
+
2228
+ if (data.probs != nullptr) {
2229
+ res->t_sampled_probs[seq_id] = data.probs;
2230
+ ggml_build_forward_expand(gf, data.probs);
2231
+ }
2232
+
2233
+ if (data.logits != nullptr) {
2234
+ res->t_sampled_logits[seq_id] = data.logits;
2235
+ ggml_build_forward_expand(gf, data.logits);
2236
+ }
2237
+
2238
+ if (data.candidates != nullptr) {
2239
+ res->t_candidates[seq_id] = data.candidates;
2240
+ ggml_build_forward_expand(gf, data.candidates);
2241
+ }
2242
+ }
2243
+
2244
+ // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
2245
+ /*
2246
+ for (const auto & [seq_id, sampler] : samplers) {
2247
+ if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
2248
+ ggml_tensor * selected_token = it->second;
2249
+ if (selected_token != nullptr) {
2250
+ llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
2251
+ }
2252
+ }
2253
+ }
2254
+ */
2255
+ }
2256
+
1940
2257
  int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
1941
2258
  // TODO move to hparams if a T5 variant appears that uses a different value
1942
2259
  const int64_t max_distance = 128;
@@ -1952,7 +2269,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
1952
2269
 
1953
2270
  if (bidirectional) {
1954
2271
  relative_bucket += (relative_position > 0) * n_buckets;
1955
- relative_position = abs(relative_position);
2272
+ relative_position = std::abs(relative_position);
1956
2273
  } else {
1957
2274
  relative_position = -std::min<int32_t>(relative_position, 0);
1958
2275
  }