whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,20 +1,21 @@
1
1
  #include "ggml.h"
2
2
  #include "common.cuh"
3
- #include "convert.cuh"
3
+ #include "unary.cuh"
4
4
  #include "mmvf.cuh"
5
+ #include "convert.cuh"
5
6
 
6
- template <typename T, typename type_acc, int ncols_dst, int block_size>
7
+ template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
7
8
  static __global__ void mul_mat_vec_f(
8
- const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
9
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
9
10
  const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
10
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11
+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
12
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
12
13
  const int row = blockIdx.x;
13
14
  const int channel_dst = blockIdx.y;
14
- const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
15
+ const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
15
16
  const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
16
17
  const int sample_dst = blockIdx.z;
17
- const int sample_x = sample_dst / sample_ratio;
18
+ const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
18
19
  const int sample_y = sample_dst;
19
20
  const int tid = threadIdx.x;
20
21
 
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
24
25
  y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
25
26
  dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
26
27
 
28
+ bool use_gate = false;
29
+ bool use_bias = false;
30
+ bool use_gate_bias = false;
31
+ ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
32
+ const T * gate_x = nullptr;
33
+ const float * x_bias = nullptr;
34
+ const float * gate_bias = nullptr;
35
+
36
+ if constexpr (has_fusion) {
37
+ use_gate = fusion.gate != nullptr;
38
+ use_bias = fusion.x_bias != nullptr;
39
+ use_gate_bias = fusion.gate_bias != nullptr;
40
+ glu_op = fusion.glu_op;
41
+
42
+ if (use_gate) {
43
+ gate_x = static_cast<const T *>(fusion.gate);
44
+ }
45
+ if (use_bias) {
46
+ x_bias = static_cast<const float *>(fusion.x_bias);
47
+ }
48
+ if (use_gate_bias) {
49
+ gate_bias = static_cast<const float *>(fusion.gate_bias);
50
+ use_gate_bias = use_gate;
51
+ } else {
52
+ use_gate_bias = false;
53
+ }
54
+ }
55
+
56
+ if (use_gate) {
57
+ gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
58
+ }
59
+ if constexpr (has_fusion) {
60
+ const int channel_bias = ids ? channel_x : channel_dst;
61
+ if (use_bias) {
62
+ x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
63
+ }
64
+ if (use_gate_bias) {
65
+ gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
66
+ }
67
+ }
68
+
27
69
  const float2 * y2 = (const float2 *) y;
28
70
 
29
71
  extern __shared__ char data_mmv[];
30
72
  float * buf_iw = (float *) data_mmv;
73
+ float * buf_iw_gate = nullptr;
74
+ if constexpr (has_fusion) {
75
+ buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
76
+ }
31
77
 
32
78
  if (block_size > warp_size) {
33
79
  if (tid < warp_size) {
34
80
  buf_iw[tid] = 0.0f;
81
+ if constexpr (has_fusion) {
82
+ if (use_gate) {
83
+ buf_iw_gate[tid] = 0.0f;
84
+ }
85
+ }
35
86
  }
36
87
  __syncthreads();
37
88
  }
38
89
 
39
90
  float sumf[ncols_dst] = {0.0f};
91
+ float sumf_gate[ncols_dst];
92
+ if constexpr (has_fusion) {
93
+ #pragma unroll
94
+ for (int j = 0; j < ncols_dst; ++j) {
95
+ sumf_gate[j] = 0.0f;
96
+ }
97
+ }
40
98
 
41
99
  if constexpr (std::is_same_v<T, float>) {
42
100
  const float2 * x2 = (const float2 *) x;
101
+ const float2 * gate_x2 = nullptr;
102
+ if constexpr (has_fusion) {
103
+ if (use_gate) {
104
+ gate_x2 = (const float2 *) gate_x;
105
+ }
106
+ }
43
107
 
44
108
  for (int col2 = tid; col2 < ncols2; col2 += block_size) {
45
109
  const float2 tmpx = x2[col2];
110
+ float2 tmpx_gate = make_float2(0.0f, 0.0f);
111
+ if constexpr (has_fusion) {
112
+ if (use_gate) {
113
+ tmpx_gate = gate_x2[col2];
114
+ }
115
+ }
46
116
 
47
117
  #pragma unroll
48
118
  for (int j = 0; j < ncols_dst; ++j) {
49
119
  const float2 tmpy = y2[j*stride_col_y2 + col2];
50
- sumf[j] += tmpx.x*tmpy.x;
51
- sumf[j] += tmpx.y*tmpy.y;
120
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
121
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
122
+
123
+ if constexpr (has_fusion) {
124
+ if (use_gate) {
125
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
126
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
127
+ }
128
+ }
52
129
  }
53
130
  }
54
131
  } else if constexpr (std::is_same_v<T, half>) {
55
132
  const half2 * x2 = (const half2 *) x;
133
+ const half2 * gate_x2 = nullptr;
134
+ if constexpr (has_fusion) {
135
+ if (use_gate) {
136
+ gate_x2 = (const half2 *) gate_x;
137
+ }
138
+ }
56
139
 
57
140
  if (std::is_same_v<type_acc, float>) {
58
141
  for (int col2 = tid; col2 < ncols2; col2 += block_size) {
59
142
  const float2 tmpx = __half22float2(x2[col2]);
60
-
143
+ float2 tmpx_gate = make_float2(0.0f, 0.0f);
144
+ if constexpr (has_fusion) {
145
+ if (use_gate) {
146
+ tmpx_gate = __half22float2(gate_x2[col2]);
147
+ }
148
+ }
61
149
  #pragma unroll
62
150
  for (int j = 0; j < ncols_dst; ++j) {
63
151
  const float2 tmpy = y2[j*stride_col_y2 + col2];
64
- sumf[j] += tmpx.x * tmpy.x;
65
- sumf[j] += tmpx.y * tmpy.y;
152
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
153
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
154
+
155
+ if constexpr (has_fusion) {
156
+ if (use_gate) {
157
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
158
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
159
+ }
160
+ }
66
161
  }
67
162
  }
68
163
  } else {
69
164
  #ifdef FP16_AVAILABLE
70
165
  half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
166
+ half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
71
167
 
72
168
  for (int col2 = tid; col2 < ncols2; col2 += block_size) {
73
169
  const half2 tmpx = x2[col2];
74
-
170
+ half2 tmpx_gate = make_half2(0.0f, 0.0f);
171
+ if constexpr (has_fusion) {
172
+ if (use_gate) {
173
+ tmpx_gate = gate_x2[col2];
174
+ }
175
+ }
75
176
  #pragma unroll
76
177
  for (int j = 0; j < ncols_dst; ++j) {
77
178
  const float2 tmpy = y2[j*stride_col_y2 + col2];
78
179
  sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
180
+
181
+ if constexpr (has_fusion) {
182
+ if (use_gate) {
183
+ sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
184
+ }
185
+ }
79
186
  }
80
187
  }
81
188
 
@@ -83,21 +190,86 @@ static __global__ void mul_mat_vec_f(
83
190
  for (int j = 0; j < ncols_dst; ++j) {
84
191
  sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
85
192
  }
193
+
194
+ if constexpr (has_fusion) {
195
+ if (use_gate) {
196
+ #pragma unroll
197
+ for (int j = 0; j < ncols_dst; ++j) {
198
+ sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
199
+ }
200
+ }
201
+ }
86
202
  #else
87
203
  NO_DEVICE_CODE;
88
204
  #endif // FP16_AVAILABLE
89
205
  }
90
206
  } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
207
+ //TODO: add support for ggml_cuda_mad for hip_bfloat162
208
+ #if defined(GGML_USE_HIP)
91
209
  const int * x2 = (const int *) x;
210
+ const int * gate_x2 = nullptr;
211
+ if constexpr (has_fusion) {
212
+ if (use_gate) {
213
+ gate_x2 = (const int *) gate_x;
214
+ }
215
+ }
92
216
  for (int col2 = tid; col2 < ncols2; col2 += block_size) {
93
217
  const int tmpx = x2[col2];
218
+ int tmpx_gate = 0;
219
+ if constexpr (has_fusion) {
220
+ if (use_gate) {
221
+ tmpx_gate = gate_x2[col2];
222
+ }
223
+ }
224
+ #pragma unroll
225
+ for (int j = 0; j < ncols_dst; ++j) {
226
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
227
+ const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
228
+ const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
229
+ ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
230
+ ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
231
+
232
+ if constexpr (has_fusion) {
233
+ if (use_gate) {
234
+ const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
235
+ const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
236
+ ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
237
+ ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
238
+ }
239
+ }
240
+ }
241
+ }
242
+ #else
243
+ const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
244
+ const nv_bfloat162 * gate_x2 = nullptr;
245
+ if constexpr (has_fusion) {
246
+ if (use_gate) {
247
+ gate_x2 = (const nv_bfloat162 *) gate_x;
248
+ }
249
+ }
250
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
251
+ const nv_bfloat162 tmpx = x2[col2];
252
+ nv_bfloat162 tmpx_gate;
253
+ if constexpr (has_fusion) {
254
+ if (use_gate) {
255
+ tmpx_gate = gate_x2[col2];
256
+ }
257
+ }
94
258
  #pragma unroll
95
259
  for (int j = 0; j < ncols_dst; ++j) {
96
260
  const float2 tmpy = y2[j*stride_col_y2 + col2];
97
- sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
98
- sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
261
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
262
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
263
+
264
+ if constexpr (has_fusion) {
265
+ if (use_gate) {
266
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
267
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
268
+ }
269
+ }
99
270
  }
100
271
  }
272
+ #endif
101
273
  } else {
102
274
  static_assert(std::is_same_v<T, void>, "unsupported type");
103
275
  }
@@ -106,13 +278,31 @@ static __global__ void mul_mat_vec_f(
106
278
  for (int j = 0; j < ncols_dst; ++j) {
107
279
  sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
108
280
 
281
+ if constexpr (has_fusion) {
282
+ if (use_gate) {
283
+ sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
284
+ }
285
+ }
286
+
109
287
  if (block_size > warp_size) {
110
288
  buf_iw[tid/warp_size] = sumf[j];
289
+ if constexpr (has_fusion) {
290
+ if (use_gate) {
291
+ buf_iw_gate[tid/warp_size] = sumf_gate[j];
292
+ }
293
+ }
111
294
  __syncthreads();
112
295
  if (tid < warp_size) {
113
296
  sumf[j] = buf_iw[tid];
114
297
  sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
298
+ if constexpr (has_fusion) {
299
+ if (use_gate) {
300
+ sumf_gate[j] = buf_iw_gate[tid];
301
+ sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
302
+ }
303
+ }
115
304
  }
305
+
116
306
  if (j < ncols_dst) {
117
307
  __syncthreads();
118
308
  }
@@ -123,12 +313,74 @@ static __global__ void mul_mat_vec_f(
123
313
  return;
124
314
  }
125
315
 
126
- dst[tid*stride_col_dst + row] = sumf[tid];
316
+ float value = sumf[tid];
317
+
318
+ if constexpr (has_fusion) {
319
+ if (use_bias) {
320
+ value += x_bias[tid*stride_col_dst + row];
321
+ }
322
+
323
+ if (use_gate) {
324
+ float gate_value = sumf_gate[tid];
325
+ if (use_gate_bias) {
326
+ gate_value += gate_bias[tid*stride_col_dst + row];
327
+ }
328
+ switch (glu_op) {
329
+ case GGML_GLU_OP_SWIGLU:
330
+ value *= ggml_cuda_op_silu_single(gate_value);
331
+ break;
332
+ case GGML_GLU_OP_GEGLU:
333
+ value *= ggml_cuda_op_gelu_single(gate_value);
334
+ break;
335
+ case GGML_GLU_OP_SWIGLU_OAI: {
336
+ value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
337
+ break;
338
+ }
339
+ default:
340
+ break;
341
+ }
342
+ }
343
+ }
344
+
345
+ dst[tid*stride_col_dst + row] = value;
346
+
347
+ if constexpr (!has_fusion) {
348
+ GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
349
+ }
350
+ }
351
+
352
+ template<typename T, typename type_acc, int ncols_dst, int block_size>
353
+ static void mul_mat_vec_f_switch_fusion(
354
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
355
+ const int64_t ncols, const int64_t nrows,
356
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
357
+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
358
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
359
+ const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
360
+
361
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
362
+ if constexpr (ncols_dst == 1) {
363
+ if (has_fusion) {
364
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
365
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
366
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
367
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
368
+ return;
369
+ }
370
+ }
371
+
372
+ GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
373
+
374
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
375
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
376
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
377
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
378
+
127
379
  }
128
380
 
129
381
  template <typename T, typename type_acc, int ncols_dst>
130
- static void launch_mul_mat_vec_f_cuda(
131
- const T * x, const float * y, const int32_t * ids, float * dst,
382
+ void launch_mul_mat_vec_f_cuda(
383
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
132
384
  const int64_t ncols, const int64_t nrows,
133
385
  const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
134
386
  const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -140,8 +392,8 @@ static void launch_mul_mat_vec_f_cuda(
140
392
  GGML_ASSERT(stride_col_y % 2 == 0);
141
393
  GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
142
394
  GGML_ASSERT( nsamples_dst % nsamples_x == 0);
143
- const int64_t channel_ratio = nchannels_dst / nchannels_x;
144
- const int64_t sample_ratio = nsamples_dst / nsamples_x;
395
+ const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
396
+ const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
145
397
 
146
398
  const int device = ggml_cuda_get_device();
147
399
  const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -160,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda(
160
412
  }
161
413
  }
162
414
 
163
- const int nbytes_shared = warp_size*sizeof(float);
415
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
416
+
417
+ const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
164
418
  const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
165
419
  const dim3 block_dims(block_size_best, 1, 1);
166
420
  switch (block_size_best) {
167
421
  case 32: {
168
- mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
169
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
170
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
171
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
422
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
423
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
424
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
425
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
172
426
  } break;
173
427
  case 64: {
174
- mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
175
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
176
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
177
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
428
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
429
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
430
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
431
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
178
432
  } break;
179
433
  case 96: {
180
- mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
181
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
182
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
183
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
434
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
435
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
436
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
184
438
  } break;
185
439
  case 128: {
186
- mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
187
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
188
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
189
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
440
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
441
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
442
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
443
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
190
444
  } break;
191
445
  case 160: {
192
- mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
193
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
194
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
195
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
446
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
447
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
448
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
449
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
196
450
  } break;
197
451
  case 192: {
198
- mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
199
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
200
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
201
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
452
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
453
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
454
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
455
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
202
456
  } break;
203
457
  case 224: {
204
- mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
205
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
206
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
207
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
458
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
459
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
460
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
208
462
  } break;
209
463
  case 256: {
210
- mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
211
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
212
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
213
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
464
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
465
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
466
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
467
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
214
468
  } break;
215
469
  default: {
216
470
  GGML_ABORT("fatal error");
@@ -220,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda(
220
474
 
221
475
  template <typename T, typename type_acc>
222
476
  static void mul_mat_vec_f_cuda_switch_ncols_dst(
223
- const T * x, const float * y, const int32_t * ids, float * dst,
477
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
224
478
  const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
225
479
  const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
226
480
  const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -230,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
230
484
  switch (ncols_dst) {
231
485
  case 1:
232
486
  launch_mul_mat_vec_f_cuda<T, type_acc, 1>
233
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
487
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
234
488
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
235
489
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
236
490
  break;
237
491
  case 2:
238
492
  launch_mul_mat_vec_f_cuda<T, type_acc, 2>
239
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
493
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
240
494
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
241
495
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
242
496
  break;
243
497
  case 3:
244
498
  launch_mul_mat_vec_f_cuda<T, type_acc, 3>
245
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
499
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
246
500
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
247
501
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
248
502
  break;
249
503
  case 4:
250
504
  launch_mul_mat_vec_f_cuda<T, type_acc, 4>
251
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
505
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
252
506
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
253
507
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
254
508
  break;
255
509
  case 5:
256
510
  launch_mul_mat_vec_f_cuda<T, type_acc, 5>
257
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
511
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
258
512
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
259
513
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
260
514
  break;
261
515
  case 6:
262
516
  launch_mul_mat_vec_f_cuda<T, type_acc, 6>
263
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
517
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
264
518
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
265
519
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
266
520
  break;
267
521
  case 7:
268
522
  launch_mul_mat_vec_f_cuda<T, type_acc, 7>
269
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
523
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
270
524
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
271
525
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
272
526
  break;
273
527
  case 8:
274
528
  launch_mul_mat_vec_f_cuda<T, type_acc, 8>
275
- (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
529
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
276
530
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
277
531
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
278
532
  break;
@@ -284,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
284
538
 
285
539
  template<typename T>
286
540
  static void mul_mat_vec_f_cuda(
287
- const T * x, const float * y, const int32_t * ids, float * dst,
541
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
288
542
  const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
289
543
  const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
290
544
  const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
291
545
  const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
292
546
  const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
293
547
  enum ggml_prec prec, cudaStream_t stream) {
548
+
294
549
  if constexpr(std::is_same_v<T, half>) {
295
550
  if (prec == GGML_PREC_DEFAULT) {
296
551
  mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
297
- (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
298
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
299
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
552
+ (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
553
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
554
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
300
555
  return;
301
556
  }
302
557
  }
303
558
  mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
304
- (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
305
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
306
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
559
+ (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
560
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
561
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
307
562
  }
308
563
 
309
- void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
564
+ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
565
+ const ggml_cuda_mm_fusion_args_host * fusion) {
310
566
  GGML_ASSERT( src1->type == GGML_TYPE_F32);
311
567
  GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
312
568
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -332,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
332
588
  const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
333
589
  float * dst_d = (float *) dst->data;
334
590
 
591
+ ggml_cuda_mm_fusion_args_device fusion_local{};
592
+
593
+ if (fusion) {
594
+ GGML_ASSERT( !ids || dst->ne[2] == 1);
595
+ GGML_ASSERT( ids || dst->ne[1] == 1);
596
+ if (fusion->x_bias) {
597
+ GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
598
+ GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
599
+ GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
600
+ fusion_local.x_bias = fusion->x_bias->data;
601
+ }
602
+ if (fusion->gate) {
603
+ GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
604
+ fusion_local.gate = fusion->gate->data;
605
+ }
606
+ if (fusion->gate_bias) {
607
+ GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
608
+ GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
609
+ GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
610
+ fusion_local.gate_bias = fusion->gate_bias->data;
611
+ }
612
+ fusion_local.glu_op = fusion->glu_op;
613
+ }
614
+
335
615
  const int64_t s01 = src0->nb[1] / ts_src0;
336
616
  const int64_t s11 = src1->nb[1] / ts_src1;
337
617
  const int64_t s1 = dst->nb[1] / ts_dst;
@@ -354,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
354
634
  switch (src0->type) {
355
635
  case GGML_TYPE_F32: {
356
636
  const float * src0_d = (const float *) src0->data;
357
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
637
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
358
638
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
359
639
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
360
640
  } break;
361
641
  case GGML_TYPE_F16: {
362
642
  const half * src0_d = (const half *) src0->data;
363
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
643
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
364
644
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
365
645
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
366
646
  } break;
367
647
  case GGML_TYPE_BF16: {
368
648
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
369
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
649
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
370
650
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
371
651
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
372
652
  } break;
@@ -393,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f(
393
673
  const int cc = ggml_cuda_info().devices[id].cc;
394
674
  const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
395
675
 
396
-
397
676
  // ggml_cuda_op provides single, contiguous matrices
398
677
  const int64_t stride_row = ne00;
399
678
  const int64_t stride_col_y = ne10;
@@ -410,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f(
410
689
  const int64_t stride_sample_y = 0;
411
690
  const int64_t stride_sample_dst = 0;
412
691
 
692
+ ggml_cuda_mm_fusion_args_device empty{};
413
693
  switch (src0->type) {
414
694
  case GGML_TYPE_F32: {
415
695
  const float * src0_d = (const float *) src0_dd_i;
416
- mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
696
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
417
697
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
418
698
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
419
699
  } break;
420
700
  case GGML_TYPE_F16: {
421
701
  const half * src0_d = (const half *) src0_dd_i;
422
- mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
702
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
423
703
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
424
704
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
425
705
  } break;
426
706
  case GGML_TYPE_BF16: {
427
707
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
428
- mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
708
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
429
709
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
430
710
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
431
711
  } break;
@@ -436,10 +716,23 @@ void ggml_cuda_op_mul_mat_vec_f(
436
716
  GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
437
717
  }
438
718
 
439
- bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
719
+ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
440
720
  if (src0_ne[0] % 2 != 0) {
441
721
  return false;
442
722
  }
723
+
724
+ const size_t ts = ggml_type_size(type);
725
+ if (src0_nb[0] != ts) {
726
+ return false;
727
+ }
728
+
729
+ // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
730
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
731
+ if (src0_nb[i] % (2*ts) != 0) {
732
+ return false;
733
+ }
734
+ }
735
+
443
736
  switch (type) {
444
737
  case GGML_TYPE_F32:
445
738
  if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
@@ -472,7 +765,10 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0
472
765
  return ne11 <= 8;
473
766
  } else if (GGML_CUDA_CC_IS_AMD(cc)) {
474
767
  if (fp16_mma_hardware_available(cc)) {
475
- if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
768
+ if (GGML_CUDA_CC_IS_RDNA3(cc)) {
769
+ return ne11 <= 3;
770
+ }
771
+ if (GGML_CUDA_CC_IS_RDNA4(cc)) {
476
772
  return ne11 <= 5;
477
773
  }
478
774
  return ne11 <= 2;