whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630):
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -19,7 +19,9 @@
19
19
  #include "ggml-cuda/count-equal.cuh"
20
20
  #include "ggml-cuda/cpy.cuh"
21
21
  #include "ggml-cuda/cross-entropy-loss.cuh"
22
+ #include "ggml-cuda/cumsum.cuh"
22
23
  #include "ggml-cuda/diagmask.cuh"
24
+ #include "ggml-cuda/diag.cuh"
23
25
  #include "ggml-cuda/fattn.cuh"
24
26
  #include "ggml-cuda/getrows.cuh"
25
27
  #include "ggml-cuda/im2col.cuh"
@@ -43,6 +45,7 @@
43
45
  #include "ggml-cuda/ssm-scan.cuh"
44
46
  #include "ggml-cuda/sum.cuh"
45
47
  #include "ggml-cuda/sumrows.cuh"
48
+ #include "ggml-cuda/top-k.cuh"
46
49
  #include "ggml-cuda/mean.cuh"
47
50
  #include "ggml-cuda/tsembd.cuh"
48
51
  #include "ggml-cuda/topk-moe.cuh"
@@ -50,8 +53,13 @@
50
53
  #include "ggml-cuda/upscale.cuh"
51
54
  #include "ggml-cuda/wkv.cuh"
52
55
  #include "ggml-cuda/gla.cuh"
56
+ #include "ggml-cuda/set.cuh"
53
57
  #include "ggml-cuda/set-rows.cuh"
54
58
  #include "ggml-cuda/pad_reflect_1d.cuh"
59
+ #include "ggml-cuda/solve_tri.cuh"
60
+ #include "ggml-cuda/tri.cuh"
61
+ #include "ggml-cuda/cumsum.cuh"
62
+ #include "ggml-cuda/fill.cuh"
55
63
  #include "ggml.h"
56
64
 
57
65
  #include <algorithm>
@@ -195,16 +203,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
195
203
  GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
196
204
 
197
205
  int64_t total_vram = 0;
198
- #ifdef GGML_CUDA_FORCE_MMQ
199
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
200
- #else
201
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
202
- #endif // GGML_CUDA_FORCE_MMQ
203
- #ifdef GGML_CUDA_FORCE_CUBLAS
204
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
205
- #else
206
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
207
- #endif // GGML_CUDA_FORCE_CUBLAS
208
206
  GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
209
207
 
210
208
  std::vector<std::pair<int, std::string>> turing_devices_without_mma;
@@ -231,10 +229,18 @@ static ggml_cuda_device_info ggml_cuda_init() {
231
229
 
232
230
  info.default_tensor_split[id] = total_vram;
233
231
  total_vram += prop.totalGlobalMem;
234
- info.devices[id].integrated = prop.integrated;
232
+ info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
235
233
  info.devices[id].nsm = prop.multiProcessorCount;
236
234
  info.devices[id].smpb = prop.sharedMemPerBlock;
237
235
  info.devices[id].warp_size = prop.warpSize;
236
+
237
+ #ifndef GGML_USE_MUSA
238
+ int supports_coop_launch = 0;
239
+ CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
240
+ info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
241
+ #else
242
+ info.devices[id].supports_cooperative_launch = false;
243
+ #endif // !(GGML_USE_MUSA)
238
244
  #if defined(GGML_USE_HIP)
239
245
  info.devices[id].smpbo = prop.sharedMemPerBlock;
240
246
 
@@ -273,6 +279,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
273
279
  } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
274
280
  turing_devices_without_mma.push_back({ id, device_name });
275
281
  }
282
+
283
+ // Temporary performance fix:
284
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
285
+ // TODO: Check for future drivers the default scheduling strategy and
286
+ // remove this call again when cudaDeviceScheduleSpin is default.
287
+ if (prop.major == 12 && prop.minor == 1) {
288
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
289
+ }
290
+
276
291
  #endif // defined(GGML_USE_HIP)
277
292
  }
278
293
 
@@ -511,7 +526,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
511
526
  };
512
527
  #endif // defined(GGML_USE_VMM)
513
528
 
514
- std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
529
+ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
530
+ [[maybe_unused]] int stream_no) {
515
531
  #if defined(GGML_USE_VMM)
516
532
  if (ggml_cuda_info().devices[device].vmm) {
517
533
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@@ -1948,8 +1964,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1948
1964
 
1949
1965
  size_t src1_stride_size = sizeof(cuda_t);
1950
1966
 
1951
- dim3 block_dims(ne13, ne12);
1952
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
1967
+ const int threads_x = 16;
1968
+ const int threads_y = 16;
1969
+ dim3 block_dims(threads_x, threads_y);
1970
+
1971
+ dim3 grid_dims(
1972
+ (ne13 + threads_x - 1) / threads_x,
1973
+ (ne12 + threads_y - 1) / threads_y
1974
+ );
1975
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
1953
1976
  src0_ptr, src1_ptr, dst_t,
1954
1977
  ptrs_src.get(), ptrs_dst.get(),
1955
1978
  ne12, ne13,
@@ -1998,6 +2021,164 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1998
2021
  }
1999
2022
  }
2000
2023
 
2024
+ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
2025
+ const ggml_tensor * ffn_gate,
2026
+ const ggml_tensor * glu,
2027
+ const ggml_tensor * ffn_up_bias = nullptr,
2028
+ const ggml_tensor * ffn_gate_bias = nullptr) {
2029
+ const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
2030
+
2031
+ if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
2032
+ return false;
2033
+ }
2034
+
2035
+ const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
2036
+ const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
2037
+
2038
+ GGML_ASSERT(ffn_up && ffn_gate && glu);
2039
+
2040
+ if (!is_mul_mat && !is_mul_mat_id) {
2041
+ return false;
2042
+ }
2043
+
2044
+ const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
2045
+
2046
+ if (has_bias) {
2047
+ if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
2048
+ return false;
2049
+ }
2050
+
2051
+ if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
2052
+ return false;
2053
+ }
2054
+
2055
+ if (expected_bias_op == GGML_OP_ADD) {
2056
+ const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
2057
+ const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
2058
+ if (!up_has_mul || !gate_has_mul) {
2059
+ return false;
2060
+ }
2061
+ } else { // GGML_OP_ADD_ID
2062
+ if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
2063
+ return false;
2064
+ }
2065
+ if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
2066
+ return false;
2067
+ }
2068
+ }
2069
+ } else {
2070
+ if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
2071
+ return false;
2072
+ }
2073
+ }
2074
+
2075
+ if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
2076
+ !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
2077
+ return false;
2078
+ }
2079
+
2080
+ if (ffn_up->src[1] != ffn_gate->src[1]) {
2081
+ return false;
2082
+ }
2083
+
2084
+ if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
2085
+ return false;
2086
+ }
2087
+
2088
+ static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
2089
+
2090
+ if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
2091
+ return false;
2092
+ }
2093
+
2094
+ if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
2095
+ return false;
2096
+ }
2097
+
2098
+ const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
2099
+ ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
2100
+
2101
+ //TODO: add support for fusion for split buffers
2102
+ if (split) {
2103
+ return false;
2104
+ }
2105
+
2106
+ return true;
2107
+ }
2108
+
2109
+ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
2110
+ ggml_tensor * src0 = tensor->src[0];
2111
+ ggml_tensor * src1 = tensor->src[1];
2112
+ const ggml_tensor * dst = tensor;
2113
+
2114
+ const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
2115
+
2116
+ bool use_mul_mat_vec_f =
2117
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
2118
+ src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2119
+
2120
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2121
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
2122
+
2123
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2124
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2125
+
2126
+ //TODO: add support for fusion for split buffers
2127
+ if (split) {
2128
+ return false;
2129
+ }
2130
+
2131
+ //we only support fusion for ncols_dst = 1
2132
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2133
+ return false;
2134
+ }
2135
+
2136
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2137
+ return false;
2138
+ }
2139
+
2140
+
2141
+ return use_mul_mat_vec_f;
2142
+ }
2143
+
2144
+ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
2145
+ ggml_tensor * src0 = tensor->src[0];
2146
+ ggml_tensor * src1 = tensor->src[1];
2147
+ const ggml_tensor * dst = tensor;
2148
+
2149
+ const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
2150
+ ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
2151
+ src0->view_src;
2152
+
2153
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
2154
+ dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
2155
+
2156
+ // fusion is not universally faster on Pascal
2157
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2158
+ if (cc <= GGML_CUDA_CC_PASCAL) {
2159
+ return false;
2160
+ }
2161
+ //we only support fusion for ncols_dst = 1
2162
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2163
+ return false;
2164
+ }
2165
+
2166
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2167
+ return false;
2168
+ }
2169
+
2170
+
2171
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2172
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2173
+
2174
+ //TODO: add support for fusion for split buffers
2175
+ if (split) {
2176
+ return false;
2177
+ }
2178
+
2179
+ return use_mul_mat_vec_q;
2180
+ }
2181
+
2001
2182
  static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2002
2183
  const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
2003
2184
 
@@ -2030,17 +2211,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2030
2211
 
2031
2212
  const int cc = ggml_cuda_info().devices[id].cc;
2032
2213
  const int warp_size = ggml_cuda_info().devices[id].warp_size;
2033
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2034
- use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
2035
- use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
2214
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2215
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2216
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2036
2217
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2037
2218
  }
2038
2219
  } else {
2039
2220
  const int cc = ggml_cuda_info().devices[ctx.device].cc;
2040
2221
  const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
2041
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2042
- use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
2043
- use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
2222
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2223
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2224
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2044
2225
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2045
2226
  }
2046
2227
 
@@ -2106,12 +2287,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2106
2287
  return;
2107
2288
  }
2108
2289
 
2109
- if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
2290
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
2110
2291
  ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
2111
2292
  return;
2112
2293
  }
2113
2294
 
2114
- if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
2295
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
2115
2296
  ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
2116
2297
  return;
2117
2298
  }
@@ -2259,6 +2440,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2259
2440
  case GGML_OP_SET_ROWS:
2260
2441
  ggml_cuda_op_set_rows(ctx, dst);
2261
2442
  break;
2443
+ case GGML_OP_SET:
2444
+ ggml_cuda_op_set(ctx, dst);
2445
+ break;
2262
2446
  case GGML_OP_DUP:
2263
2447
  ggml_cuda_dup(ctx, dst);
2264
2448
  break;
@@ -2334,6 +2518,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2334
2518
  case GGML_UNARY_OP_ELU:
2335
2519
  ggml_cuda_op_elu(ctx, dst);
2336
2520
  break;
2521
+ case GGML_UNARY_OP_XIELU:
2522
+ ggml_cuda_op_xielu(ctx, dst);
2523
+ break;
2524
+ case GGML_UNARY_OP_FLOOR:
2525
+ ggml_cuda_op_floor(ctx, dst);
2526
+ break;
2527
+ case GGML_UNARY_OP_CEIL:
2528
+ ggml_cuda_op_ceil(ctx, dst);
2529
+ break;
2530
+ case GGML_UNARY_OP_ROUND:
2531
+ ggml_cuda_op_round(ctx, dst);
2532
+ break;
2533
+ case GGML_UNARY_OP_TRUNC:
2534
+ ggml_cuda_op_trunc(ctx, dst);
2535
+ break;
2536
+ case GGML_UNARY_OP_EXPM1:
2537
+ ggml_cuda_op_expm1(ctx, dst);
2538
+ break;
2539
+ case GGML_UNARY_OP_SOFTPLUS:
2540
+ ggml_cuda_op_softplus(ctx, dst);
2541
+ break;
2337
2542
  default:
2338
2543
  return false;
2339
2544
  }
@@ -2437,6 +2642,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2437
2642
  case GGML_OP_PERMUTE:
2438
2643
  case GGML_OP_TRANSPOSE:
2439
2644
  break;
2645
+ case GGML_OP_DIAG:
2646
+ ggml_cuda_op_diag(ctx, dst);
2647
+ break;
2440
2648
  case GGML_OP_DIAG_MASK_INF:
2441
2649
  ggml_cuda_op_diag_mask_inf(ctx, dst);
2442
2650
  break;
@@ -2479,6 +2687,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2479
2687
  case GGML_OP_SUM:
2480
2688
  ggml_cuda_op_sum(ctx, dst);
2481
2689
  break;
2690
+ case GGML_OP_CUMSUM:
2691
+ ggml_cuda_op_cumsum(ctx, dst);
2692
+ break;
2482
2693
  case GGML_OP_SUM_ROWS:
2483
2694
  ggml_cuda_op_sum_rows(ctx, dst);
2484
2695
  break;
@@ -2491,6 +2702,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2491
2702
  case GGML_OP_SSM_SCAN:
2492
2703
  ggml_cuda_op_ssm_scan(ctx, dst);
2493
2704
  break;
2705
+ case GGML_OP_TOP_K:
2706
+ ggml_cuda_op_top_k(ctx, dst);
2707
+ break;
2494
2708
  case GGML_OP_ARGSORT:
2495
2709
  ggml_cuda_op_argsort(ctx, dst);
2496
2710
  break;
@@ -2500,6 +2714,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2500
2714
  case GGML_OP_CROSS_ENTROPY_LOSS:
2501
2715
  ggml_cuda_cross_entropy_loss(ctx, dst);
2502
2716
  break;
2717
+ case GGML_OP_TRI:
2718
+ ggml_cuda_op_tri(ctx, dst);
2719
+ break;
2503
2720
  case GGML_OP_RWKV_WKV6:
2504
2721
  ggml_cuda_op_rwkv_wkv6(ctx, dst);
2505
2722
  break;
@@ -2518,6 +2735,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2518
2735
  case GGML_OP_OPT_STEP_SGD:
2519
2736
  ggml_cuda_opt_step_sgd(ctx, dst);
2520
2737
  break;
2738
+ case GGML_OP_SOLVE_TRI:
2739
+ ggml_cuda_op_solve_tri(ctx, dst);
2740
+ break;
2741
+ case GGML_OP_FILL:
2742
+ ggml_cuda_op_fill(ctx, dst);
2743
+ break;
2521
2744
  default:
2522
2745
  return false;
2523
2746
  }
@@ -2630,11 +2853,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
2630
2853
  }
2631
2854
 
2632
2855
  #ifdef USE_CUDA_GRAPH
2633
- static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2634
- bool use_cuda_graph) {
2856
+ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2635
2857
 
2858
+ bool use_cuda_graph = true;
2636
2859
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2637
- cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
2638
2860
 
2639
2861
  const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2640
2862
  const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -2685,118 +2907,105 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2685
2907
  #endif
2686
2908
  }
2687
2909
 
2688
- if (node->op == GGML_OP_CPY) {
2689
-
2690
- // Store the pointers which are updated for each token, such that these can be sent
2691
- // to the device and accessed using indirection from CUDA graph
2692
- cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
2693
-
2694
- // store a pointer to each copy op CUDA kernel to identify it later
2695
- void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2696
- if (!ptr) {
2697
- use_cuda_graph = false;
2698
- #ifndef NDEBUG
2699
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
2700
- #endif
2701
- }
2702
- }
2703
-
2704
2910
  if (!use_cuda_graph) {
2705
2911
  break;
2706
2912
  }
2707
2913
  }
2708
2914
 
2709
- if (use_cuda_graph) {
2710
- cuda_ctx->cuda_graph->use_cpy_indirection = true;
2711
- // copy pointers to GPU so they can be accessed via indirection within CUDA graph
2712
- ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
2713
- }
2714
-
2715
2915
  return use_cuda_graph;
2716
2916
  }
2717
2917
 
2718
- static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2719
- graph_node_properties->node_address = node->data;
2720
- graph_node_properties->node_op = node->op;
2918
+ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919
+ props->node_address = node->data;
2920
+ props->node_op = node->op;
2721
2921
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2722
- graph_node_properties->ne[i] = node->ne[i];
2723
- graph_node_properties->nb[i] = node->nb[i];
2922
+ props->ne[i] = node->ne[i];
2923
+ props->nb[i] = node->nb[i];
2724
2924
  }
2725
2925
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2726
- graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2926
+ props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2727
2927
  }
2728
- memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2928
+ memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2729
2929
  }
2730
2930
 
2731
- static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2732
- if (node->data != graph_node_properties->node_address &&
2733
- node->op != GGML_OP_CPY &&
2931
+ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2932
+ if (node->data != props->node_address &&
2734
2933
  node->op != GGML_OP_VIEW) {
2735
2934
  return false;
2736
2935
  }
2737
2936
 
2738
- if (node->op != graph_node_properties->node_op) {
2937
+ if (node->op != props->node_op) {
2739
2938
  return false;
2740
2939
  }
2741
2940
 
2742
2941
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2743
- if (node->ne[i] != graph_node_properties->ne[i]) {
2942
+ if (node->ne[i] != props->ne[i]) {
2744
2943
  return false;
2745
2944
  }
2746
- if (node->nb[i] != graph_node_properties->nb[i]) {
2945
+ if (node->nb[i] != props->nb[i]) {
2747
2946
  return false;
2748
2947
  }
2749
2948
  }
2750
2949
 
2751
2950
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2752
2951
  if (node->src[i] &&
2753
- node->src[i]->data != graph_node_properties->src_address[i] &&
2754
- node->op != GGML_OP_CPY &&
2952
+ node->src[i]->data != props->src_address[i] &&
2755
2953
  node->op != GGML_OP_VIEW
2756
2954
  ) {
2757
2955
  return false;
2758
2956
  }
2759
2957
  }
2760
2958
 
2761
- if (node->op == GGML_OP_SCALE &&
2762
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2959
+ if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960
+ memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2763
2961
  return false;
2764
2962
  }
2765
2963
 
2766
2964
  return true;
2767
2965
  }
2768
2966
 
2769
- static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2967
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2770
2968
 
2771
- bool cuda_graph_update_required = false;
2969
+ bool res = false;
2772
2970
 
2773
2971
  if (cuda_ctx->cuda_graph->instance == nullptr) {
2774
- cuda_graph_update_required = true;
2972
+ res = true;
2775
2973
  }
2776
2974
 
2777
2975
  // Check if the graph size has changed
2778
- if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2779
- cuda_graph_update_required = true;
2780
- cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2976
+ if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
2977
+ res = true;
2978
+ cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
2781
2979
  }
2782
2980
 
2783
2981
  // Loop over nodes in GGML graph to determine if CUDA graph update is required
2784
2982
  // and store properties to allow this comparison for the next token
2785
2983
  for (int i = 0; i < cgraph->n_nodes; i++) {
2786
- bool has_matching_properties = true;
2787
- if (!cuda_graph_update_required) {
2788
- has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2984
+ bool props_match = true;
2985
+ if (!res) {
2986
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
2789
2987
  }
2790
- if (!has_matching_properties) {
2791
- cuda_graph_update_required = true;
2988
+ if (!props_match) {
2989
+ res = true;
2792
2990
  }
2793
- set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2991
+ ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
2794
2992
  }
2795
2993
 
2796
- return cuda_graph_update_required;
2994
+ for (int i = 0; i < cgraph->n_leafs; i++) {
2995
+ bool props_match= true;
2996
+ if (!res) {
2997
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
2998
+ }
2999
+ if (!props_match) {
3000
+ res = true;
3001
+ }
3002
+ ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
3003
+ }
3004
+
3005
+ return res;
2797
3006
  }
2798
3007
 
2799
- static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
3008
+ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
2800
3009
 
2801
3010
  #if CUDART_VERSION >= 12000
2802
3011
  cudaGraphExecUpdateResultInfo result_info;
@@ -2824,6 +3033,40 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2824
3033
  }
2825
3034
  #endif
2826
3035
 
3036
+ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3037
+ const ggml_tensor * view,
3038
+ const ggml_tensor * set_rows) {
3039
+
3040
+ if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
3041
+ return false;
3042
+ }
3043
+ // ne3 not tested
3044
+ if (rope->src[0]->ne[3] != 1) {
3045
+ return false;
3046
+ }
3047
+
3048
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
3049
+ return false;
3050
+ }
3051
+
3052
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
3053
+ return false;
3054
+ }
3055
+
3056
+ // The view should flatten two dims of rope into one dim
3057
+ if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
3058
+ return false;
3059
+ }
3060
+
3061
+ // Only norm/neox shaders have the fusion code
3062
+ const int mode = ((const int32_t *) rope->op_params)[2];
3063
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
3064
+ return false;
3065
+ }
3066
+
3067
+ return true;
3068
+ }
3069
+
2827
3070
  static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
2828
3071
  #ifndef NDEBUG
2829
3072
  const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -2831,39 +3074,94 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2831
3074
  #endif
2832
3075
 
2833
3076
  //TODO: remove special case once ggml_can_fuse can handle empty nodes
2834
- std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
2835
- std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
3077
+ std::initializer_list<enum ggml_op> topk_moe_ops =
3078
+ ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
3079
+ std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
3080
+ ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
3081
+ std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
3082
+ ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
3083
+
3084
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3085
+ const std::initializer_list<enum ggml_op> & list2) {
3086
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3087
+ };
2836
3088
 
2837
- if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
3089
+ if (is_equal(topk_moe_ops_with_norm, ops) &&
3090
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
3091
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
3092
+ ggml_tensor * weights = cgraph->nodes[node_idx + 9];
3093
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3094
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3095
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
2838
3096
 
2839
- if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
2840
- return false;
3097
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3098
+ return true;
2841
3099
  }
3100
+ }
2842
3101
 
2843
- for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
2844
- if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
2845
- }
3102
+ if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
2846
3103
  ggml_tensor * softmax = cgraph->nodes[node_idx];
2847
- ggml_tensor * weights = cgraph->nodes[node_idx+8];
3104
+ ggml_tensor * weights = cgraph->nodes[node_idx + 4];
3105
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3106
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3107
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
2848
3108
 
2849
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
3109
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
2850
3110
  return true;
2851
3111
  }
2852
3112
  }
2853
3113
 
2854
- if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
3114
+ if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
3115
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
3116
+ ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
3117
+ ggml_tensor * weights = cgraph->nodes[node_idx + 5];
3118
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
3119
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
3120
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
2855
3121
 
2856
- if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
2857
- return false;
3122
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3123
+ return true;
2858
3124
  }
3125
+ }
3126
+
3127
+ std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
3128
+ std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
3129
+
3130
+ std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
3131
+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2859
3132
 
2860
- for (size_t i = 0; i < topk_moe_ops.size(); i++) {
2861
- if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
3133
+ if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
3134
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
3135
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3136
+ const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
3137
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
3138
+ const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
3139
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
3140
+
3141
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3142
+ return true;
2862
3143
  }
3144
+ }
2863
3145
 
2864
- ggml_tensor * softmax = cgraph->nodes[node_idx];
2865
- ggml_tensor * weights = cgraph->nodes[node_idx+4];
2866
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
3146
+ if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
3147
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3148
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3149
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
3150
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
3151
+
3152
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3153
+ return true;
3154
+ }
3155
+ }
3156
+
3157
+ std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
3158
+
3159
+ if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3160
+ const ggml_tensor * rope = cgraph->nodes[node_idx];
3161
+ const ggml_tensor * view = cgraph->nodes[node_idx + 1];
3162
+ const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
3163
+
3164
+ if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
2867
3165
  return true;
2868
3166
  }
2869
3167
  }
@@ -2898,7 +3196,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2898
3196
  }
2899
3197
 
2900
3198
  //if rms norm is the B operand, then we don't handle broadcast
2901
- if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
3199
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
2902
3200
  return false;
2903
3201
  }
2904
3202
 
@@ -2938,42 +3236,192 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2938
3236
  return false;
2939
3237
  }
2940
3238
 
2941
- static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2942
- bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
3239
+ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3240
+ bool graph_evaluated_or_captured = false;
3241
+
2943
3242
  // flag used to determine whether it is an integrated_gpu
2944
- const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3243
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3244
+
3245
+ ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
3246
+ bool is_concurrent_event_active = false;
3247
+ ggml_cuda_concurrent_event * concurrent_event = nullptr;
3248
+ bool should_launch_concurrent_events = false;
3249
+
3250
+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3251
+ if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3252
+ concurrent_event = &stream_ctx.concurrent_events[node];
3253
+
3254
+ is_concurrent_event_active = true;
3255
+
3256
+ GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3257
+
3258
+ cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3259
+ GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3260
+ CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3261
+
3262
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3263
+ cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3264
+ CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3265
+ }
3266
+ }
3267
+ };
2945
3268
 
2946
3269
  while (!graph_evaluated_or_captured) {
2947
3270
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2948
3271
  // With the use of CUDA graphs, the execution will be performed by the graph launch.
2949
3272
  if (!use_cuda_graph || cuda_graph_update_required) {
3273
+ [[maybe_unused]] int prev_i = 0;
3274
+
3275
+ if (stream_ctx.concurrent_events.size() > 0) {
3276
+ should_launch_concurrent_events = true;
3277
+ for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
3278
+ should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
3279
+ }
3280
+ }
3281
+
3282
+ if (should_launch_concurrent_events) {
3283
+ // Restore original node order within each concurrent region to enable fusion within streams
3284
+
3285
+ std::unordered_map<const ggml_tensor *, int> node_to_idx;
3286
+ node_to_idx.reserve(cgraph->n_nodes);
3287
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
3288
+ node_to_idx[cgraph->nodes[i]] = i;
3289
+ }
3290
+
3291
+ for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
3292
+ // Find positions of all nodes from this event in the current graph
3293
+ std::vector<int> positions;
3294
+ positions.reserve(event.original_order.size());
3295
+
3296
+ bool all_found = true;
3297
+ for (const ggml_tensor * orig_node : event.original_order) {
3298
+ auto it = node_to_idx.find(orig_node);
3299
+ if (it != node_to_idx.end()) {
3300
+ positions.push_back(it->second);
3301
+ } else {
3302
+ all_found = false;
3303
+ break;
3304
+ }
3305
+ }
3306
+
3307
+ if (!all_found || positions.size() != event.original_order.size()) {
3308
+ continue;
3309
+ }
3310
+
3311
+ // Sort positions to get contiguous range
3312
+ std::vector<int> sorted_positions = positions;
3313
+ std::sort(sorted_positions.begin(), sorted_positions.end());
3314
+
3315
+ bool is_contiguous = true;
3316
+ for (size_t i = 1; i < sorted_positions.size(); ++i) {
3317
+ if (sorted_positions[i] != sorted_positions[i-1] + 1) {
3318
+ is_contiguous = false;
3319
+ break;
3320
+ }
3321
+ }
3322
+
3323
+ if (!is_contiguous) {
3324
+ continue;
3325
+ }
3326
+
3327
+ // Restore original order at the sorted positions
3328
+ int start_pos = sorted_positions[0];
3329
+ for (size_t i = 0; i < event.original_order.size(); ++i) {
3330
+ cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
3331
+ }
3332
+ }
3333
+ } else {
3334
+ stream_ctx.concurrent_events.clear();
3335
+ }
2950
3336
 
2951
3337
  for (int i = 0; i < cgraph->n_nodes; i++) {
2952
3338
  ggml_tensor * node = cgraph->nodes[i];
3339
+ if (is_concurrent_event_active) {
3340
+ GGML_ASSERT(concurrent_event);
3341
+
3342
+ if (node == concurrent_event->join_node) {
3343
+ cuda_ctx->curr_stream_no = 0;
3344
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3345
+ // Wait on join events of forked streams in the main stream
3346
+ CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
3347
+ cuda_ctx->stream(cuda_ctx->device, i)));
3348
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
3349
+ }
3350
+
3351
+ is_concurrent_event_active = false;
3352
+ concurrent_event = nullptr;
3353
+ } else {
3354
+ GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
3355
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3356
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3357
+ }
3358
+ } else if (i - prev_i > 1) {
3359
+ //the previous node was fused
3360
+ const ggml_tensor * prev_node = cgraph->nodes[i - 1];
3361
+ try_launch_concurrent_event(prev_node);
3362
+
3363
+ if (is_concurrent_event_active) {
3364
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3365
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3366
+ }
3367
+ }
3368
+
3369
+ #ifdef GGML_CUDA_DEBUG
3370
+ const int nodes_fused = i - prev_i - 1;
3371
+ if (nodes_fused > 0) {
3372
+ GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
3373
+ }
3374
+ #endif
3375
+ prev_i = i;
2953
3376
 
2954
3377
  if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2955
3378
  continue;
2956
3379
  }
2957
3380
 
3381
+
3382
+ // start of fusion operations
2958
3383
  static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
2959
3384
  if (!disable_fusion) {
2960
3385
 
2961
3386
  if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
2962
- ggml_tensor * weights = cgraph->nodes[i+8];
2963
- ggml_tensor * selected_experts = cgraph->nodes[i+3];
2964
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
2965
- i += 8;
3387
+ ggml_tensor * weights = cgraph->nodes[i + 9];
3388
+ ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3389
+ ggml_tensor * clamp = cgraph->nodes[i + 7];
3390
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
3391
+ /*delayed softmax*/ false, clamp);
3392
+ i += 9;
2966
3393
  continue;
2967
3394
  }
2968
3395
 
2969
3396
  if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
2970
- ggml_tensor * weights = cgraph->nodes[i+4];
2971
- ggml_tensor * selected_experts = cgraph->nodes[i+3];
2972
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
3397
+ ggml_tensor * weights = cgraph->nodes[i + 4];
3398
+ ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3399
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
3400
+ /*delayed softmax*/ false);
2973
3401
  i += 4;
2974
3402
  continue;
2975
3403
  }
2976
3404
 
3405
+ if (ggml_cuda_can_fuse(cgraph, i,
3406
+ ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
3407
+ ggml_tensor * weights = cgraph->nodes[i + 5];
3408
+ ggml_tensor * ids = cgraph->nodes[i + 1];
3409
+
3410
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
3411
+ /*delayed_softmax*/ true);
3412
+ i += 5;
3413
+ continue;
3414
+ }
3415
+
3416
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3417
+ ggml_tensor * rope = cgraph->nodes[i];
3418
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
3419
+
3420
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3421
+ i += 2;
3422
+ continue;
3423
+ }
3424
+
2977
3425
  if (node->op == GGML_OP_ADD) {
2978
3426
  int n_fuse = 0;
2979
3427
  ggml_op ops[8];
@@ -3005,6 +3453,195 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3005
3453
  }
3006
3454
  }
3007
3455
 
3456
+ bool fused_mul_mat_vec = false;
3457
+ int fused_node_count = 0;
3458
+
3459
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3460
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3461
+
3462
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3463
+ ggml_tensor * glu = cgraph->nodes[i + 4];
3464
+ ggml_tensor * gate_bias_n = glu->src[0];
3465
+ ggml_tensor * up_bias_n = glu->src[1];
3466
+
3467
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3468
+ ggml_tensor * gate_n = nullptr;
3469
+ ggml_tensor * up_n = nullptr;
3470
+
3471
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3472
+ gate_n = cgraph->nodes[i];
3473
+ up_n = cgraph->nodes[i + 2];
3474
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3475
+ gate_n = cgraph->nodes[i + 2];
3476
+ up_n = cgraph->nodes[i];
3477
+ } else {
3478
+ continue;
3479
+ }
3480
+
3481
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3482
+ if (op_bias == GGML_OP_ADD) {
3483
+ if (bias_node->src[0] == mul_node) {
3484
+ return bias_node->src[1];
3485
+ }
3486
+ if (bias_node->src[1] == mul_node) {
3487
+ return bias_node->src[0];
3488
+ }
3489
+ return (ggml_tensor *) nullptr;
3490
+ }
3491
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3492
+ GGML_ASSERT(bias_node->src[0] == mul_node);
3493
+ return bias_node->src[1];
3494
+ };
3495
+
3496
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
3497
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3498
+
3499
+ if (!up_bias_tensor || !gate_bias_tensor) {
3500
+ continue;
3501
+ }
3502
+
3503
+ // we don't support repeating adds
3504
+ if (bias_op == GGML_OP_ADD &&
3505
+ (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3506
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3507
+ continue;
3508
+ }
3509
+
3510
+ const ggml_tensor * src0 = up_n->src[0];
3511
+ const ggml_tensor * src1 = up_n->src[1];
3512
+ const ggml_tensor * ids = up_n->src[2];
3513
+
3514
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3515
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3516
+ fusion_data.gate = gate_n->src[0];
3517
+ fusion_data.x_bias = up_bias_tensor;
3518
+ fusion_data.gate_bias = gate_bias_tensor;
3519
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3520
+
3521
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3522
+ fused_mul_mat_vec = true;
3523
+ fused_node_count = 5;
3524
+ break;
3525
+ }
3526
+
3527
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3528
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3529
+ fusion_data.gate = gate_n->src[0];
3530
+ fusion_data.x_bias = up_bias_tensor;
3531
+ fusion_data.gate_bias = gate_bias_tensor;
3532
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3533
+
3534
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3535
+ fused_mul_mat_vec = true;
3536
+ fused_node_count = 5;
3537
+ break;
3538
+ }
3539
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3540
+ ggml_tensor * glu = cgraph->nodes[i + 2];
3541
+ ggml_tensor * gate = glu->src[0];
3542
+ ggml_tensor * up = glu->src[1];
3543
+
3544
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3545
+ || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3546
+
3547
+ if (!ok) continue;
3548
+
3549
+ const ggml_tensor * src0 = up->src[0];
3550
+ const ggml_tensor * src1 = up->src[1];
3551
+ const ggml_tensor * ids = up->src[2];
3552
+
3553
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3554
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3555
+ fusion_data.gate = gate->src[0];
3556
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3557
+
3558
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3559
+ fused_mul_mat_vec = true;
3560
+ fused_node_count = 3;
3561
+ break;
3562
+ }
3563
+
3564
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3565
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3566
+ fusion_data.gate = gate->src[0];
3567
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3568
+
3569
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3570
+ fused_mul_mat_vec = true;
3571
+ fused_node_count = 3;
3572
+ break;
3573
+ }
3574
+ }
3575
+ }
3576
+
3577
+ if (fused_mul_mat_vec) {
3578
+ i += fused_node_count - 1;
3579
+ continue;
3580
+ }
3581
+
3582
+ fused_mul_mat_vec = false;
3583
+ fused_node_count = 0;
3584
+
3585
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3586
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3587
+
3588
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3589
+ continue;
3590
+ }
3591
+
3592
+ ggml_tensor * mm_node = cgraph->nodes[i];
3593
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
3594
+
3595
+ ggml_tensor * bias_tensor = nullptr;
3596
+ if (bias_op == GGML_OP_ADD) {
3597
+ if (bias_node->src[0] == mm_node) {
3598
+ bias_tensor = bias_node->src[1];
3599
+ } else if (bias_node->src[1] == mm_node) {
3600
+ bias_tensor = bias_node->src[0];
3601
+ } else {
3602
+ continue;
3603
+ }
3604
+ } else {
3605
+ if (bias_node->src[0] != mm_node) {
3606
+ continue;
3607
+ }
3608
+ bias_tensor = bias_node->src[1];
3609
+ }
3610
+
3611
+ const ggml_tensor * src0 = mm_node->src[0];
3612
+ const ggml_tensor * src1 = mm_node->src[1];
3613
+ const ggml_tensor * ids = mm_node->src[2];
3614
+
3615
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3616
+ continue;
3617
+ }
3618
+
3619
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3620
+ continue;
3621
+ }
3622
+
3623
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3624
+ fusion_data.x_bias = bias_tensor;
3625
+
3626
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3627
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3628
+ fused_mul_mat_vec = true;
3629
+ fused_node_count = 2;
3630
+ break;
3631
+ }
3632
+
3633
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3634
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3635
+ fused_mul_mat_vec = true;
3636
+ fused_node_count = 2;
3637
+ break;
3638
+ }
3639
+ }
3640
+
3641
+ if (fused_mul_mat_vec) {
3642
+ i += fused_node_count - 1;
3643
+ continue;
3644
+ }
3008
3645
 
3009
3646
  if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3010
3647
  ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
@@ -3035,13 +3672,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3035
3672
  }
3036
3673
  #else
3037
3674
  GGML_UNUSED(integrated);
3038
- #endif // NDEBUG
3675
+ #endif // NDEBUG
3039
3676
 
3040
3677
  bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
3041
3678
  if (!ok) {
3042
3679
  GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
3043
3680
  }
3044
3681
  GGML_ASSERT(ok);
3682
+
3683
+ if (!is_concurrent_event_active) {
3684
+ try_launch_concurrent_event(node);
3685
+ }
3045
3686
  }
3046
3687
  }
3047
3688
 
@@ -3069,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3069
3710
  CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
3070
3711
  }
3071
3712
  if (cuda_graph_update_required) { // Update graph executable
3072
- update_cuda_graph_executable(cuda_ctx);
3713
+ ggml_cuda_graph_update_executable(cuda_ctx);
3073
3714
  }
3074
3715
  // Launch graph
3075
3716
  CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@@ -3079,60 +3720,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3079
3720
  }
3080
3721
  }
3081
3722
 
3082
- static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3083
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3084
-
3085
- ggml_cuda_set_device(cuda_ctx->device);
3723
+ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
3086
3724
 
3087
3725
  #ifdef USE_CUDA_GRAPH
3088
- static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
3089
3726
 
3090
- // Objects required for CUDA Graph
3091
3727
  if (cuda_ctx->cuda_graph == nullptr) {
3092
3728
  cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
3093
3729
  }
3094
3730
 
3095
- bool use_cuda_graph = true;
3096
- bool cuda_graph_update_required = false;
3097
-
3098
3731
  if (cuda_ctx->cuda_graph->graph == nullptr) {
3099
3732
  if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3100
3733
  cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3101
- #ifndef NDEBUG
3102
3734
  GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
3103
- #endif
3104
3735
  }
3105
3736
  }
3106
3737
 
3107
- // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
3108
- // or previous graph capture failure.
3109
- // Also disable for multi-gpu for now. TO DO investigate
3110
- if (disable_cuda_graphs_due_to_env
3111
- || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
3112
- || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
3113
- || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
3114
- use_cuda_graph = false;
3115
- }
3738
+ return cuda_ctx->cuda_graph->is_enabled();
3739
+ #else
3740
+ GGML_UNUSED(cuda_ctx);
3741
+ return false;
3742
+ #endif // USE_CUDA_GRAPH
3743
+ }
3116
3744
 
3117
- if (use_cuda_graph) {
3118
- cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
3745
+ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3746
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3119
3747
 
3120
- use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
3748
+ ggml_cuda_set_device(cuda_ctx->device);
3121
3749
 
3122
- // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
3123
- if (use_cuda_graph && cuda_graph_update_required) {
3124
- cuda_ctx->cuda_graph->number_consecutive_updates++;
3125
- } else {
3126
- cuda_ctx->cuda_graph->number_consecutive_updates = 0;
3127
- }
3750
+ bool use_cuda_graph = false;
3751
+ bool cuda_graph_update_required = false;
3128
3752
 
3129
- if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
3130
- cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
3131
- #ifndef NDEBUG
3132
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
3133
- #endif
3134
- }
3753
+ #ifdef USE_CUDA_GRAPH
3754
+ use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3755
+
3756
+ if (cuda_ctx->cuda_graph->is_enabled()) {
3757
+ cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3758
+ use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
3759
+
3760
+ cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
3135
3761
  }
3762
+ #endif // USE_CUDA_GRAPH
3136
3763
 
3137
3764
  if (use_cuda_graph && cuda_graph_update_required) {
3138
3765
  // Start CUDA graph capture
@@ -3144,18 +3771,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3144
3771
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3145
3772
  }
3146
3773
 
3147
- if (!use_cuda_graph) {
3148
- cuda_ctx->cuda_graph->use_cpy_indirection = false;
3149
- }
3150
-
3151
- #else
3152
- bool use_cuda_graph = false;
3153
- bool cuda_graph_update_required = false;
3154
- #endif // USE_CUDA_GRAPH
3155
-
3156
- bool graph_evaluated_or_captured = false;
3157
-
3158
- evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
3774
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
3159
3775
 
3160
3776
  return GGML_STATUS_SUCCESS;
3161
3777
  }
@@ -3185,6 +3801,243 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
3185
3801
  }
3186
3802
  }
3187
3803
 
3804
+ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
3805
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3806
+
3807
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3808
+
3809
+ static bool enable_graph_optimization = [] {
3810
+ const char * env = getenv("GGML_CUDA_GRAPH_OPT");
3811
+ return env != nullptr && atoi(env) == 1;
3812
+ }();
3813
+
3814
+ if (!enable_graph_optimization) {
3815
+ return;
3816
+ }
3817
+
3818
+ ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
3819
+ stream_context.reset();
3820
+
3821
+ if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
3822
+ return;
3823
+ }
3824
+
3825
+ // number of out-degrees for a particular node
3826
+ std::unordered_map<const ggml_tensor *, int> fan_out;
3827
+ // reverse mapping of node to index in the cgraph
3828
+ std::unordered_map<const ggml_tensor *, int> node_indices;
3829
+
3830
+ const auto & is_noop = [](const ggml_tensor * node) -> bool {
3831
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
3832
+ node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
3833
+ };
3834
+
3835
+ const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
3836
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
3837
+ if (dst->src[s] == src) {
3838
+ return true;
3839
+ }
3840
+ }
3841
+ // implicit dependency if they view the same tensor
3842
+ const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
3843
+ const ggml_tensor * src2 = src->view_src ? src->view_src : src;
3844
+ if (dst2 == src2) {
3845
+ return true;
3846
+ }
3847
+ return false;
3848
+ };
3849
+
3850
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
3851
+ const ggml_tensor * node = cgraph->nodes[node_idx];
3852
+ node_indices[node] = node_idx;
3853
+
3854
+ if (is_noop(node)) {
3855
+ continue;
3856
+ }
3857
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3858
+ const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
3859
+ //TODO: check why nrows > 1 fails
3860
+ if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
3861
+ fan_out[src] += 1;
3862
+ }
3863
+ }
3864
+ }
3865
+
3866
+ // Target Q, K, V for concurrency
3867
+ // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
3868
+ // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
3869
+ // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
3870
+ // 3. account for all branches from the fork to the join
3871
+ // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
3872
+ // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
3873
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
3874
+
3875
+ const int min_fan_out = 3;
3876
+ const int max_fan_out = 3;
3877
+
3878
+ // store {fork_idx, join_idx}
3879
+ std::vector<std::pair<int, int>> concurrent_node_ranges;
3880
+
3881
+ for (const auto & [root_node, count] : fan_out) {
3882
+ if (count >= min_fan_out && count <= max_fan_out) {
3883
+ const int root_node_idx = node_indices[root_node];
3884
+
3885
+ // only optimize for attn_norm
3886
+ // TODO: make this more generic
3887
+ if (!strstr(root_node->name, "attn_norm")) {
3888
+ continue;
3889
+ }
3890
+
3891
+ bool is_part_of_event = false;
3892
+ for (const auto & [start, end] : concurrent_node_ranges) {
3893
+ if (root_node_idx >= start && root_node_idx <= end) {
3894
+ is_part_of_event = true;
3895
+ }
3896
+ }
3897
+
3898
+ if (is_part_of_event) {
3899
+ continue;
3900
+ }
3901
+
3902
+ std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
3903
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
3904
+ const ggml_tensor * node = cgraph->nodes[i];
3905
+ if (!is_noop(node) && depends_on(node, root_node)) {
3906
+ nodes_per_branch.push_back({ node });
3907
+ }
3908
+ }
3909
+
3910
+ GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
3911
+
3912
+ //find the join point
3913
+ const ggml_tensor * join_node = nullptr;
3914
+
3915
+ const auto & belongs_to_branch = [&](const ggml_tensor * node,
3916
+ const std::vector<const ggml_tensor *> & branch) -> bool {
3917
+ for (const ggml_tensor * n : branch) {
3918
+ if (depends_on(node, n)) {
3919
+ return true;
3920
+ }
3921
+ }
3922
+ return false;
3923
+ };
3924
+
3925
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
3926
+ const ggml_tensor * curr_node = cgraph->nodes[i];
3927
+
3928
+ int num_joins = 0;
3929
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3930
+ if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
3931
+ num_joins++;
3932
+ }
3933
+ }
3934
+
3935
+ if (num_joins >= 2) {
3936
+ join_node = curr_node;
3937
+ break;
3938
+ }
3939
+
3940
+ bool found_branch = false;
3941
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3942
+ std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
3943
+ if (belongs_to_branch(curr_node, branch_vec)) {
3944
+ //continue accumulating
3945
+ if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
3946
+ branch_vec.push_back(curr_node);
3947
+ }
3948
+ found_branch = true;
3949
+ }
3950
+ }
3951
+
3952
+ if (!found_branch && is_noop(curr_node)) {
3953
+ // we can put it in any branch because it will be ignored
3954
+ nodes_per_branch[0].push_back({ curr_node });
3955
+ }
3956
+ }
3957
+
3958
+ if (join_node) {
3959
+ //Create ggml_cuda_concurrent_event
3960
+ ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
3961
+ concurrent_event.join_node = join_node;
3962
+
3963
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3964
+ for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
3965
+ concurrent_event.stream_mapping[n] = branch_idx + 1;
3966
+ }
3967
+ }
3968
+
3969
+ int fork_node_idx = node_indices[root_node];
3970
+ int join_node_idx = node_indices[join_node];
3971
+
3972
+ int current_branch_idx = 0;
3973
+ int current_node_idx = fork_node_idx + 1;
3974
+ const int n_branches = nodes_per_branch.size();
3975
+
3976
+ int total_branch_nodes = 0;
3977
+ for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
3978
+ total_branch_nodes += branch_nodes.size();
3979
+ }
3980
+
3981
+ // there are other nodes in the middle which are unaccounted for
3982
+ // usually (cpy) nodes, then ignore this fork
3983
+ if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
3984
+ GGML_LOG_DEBUG(
3985
+ "Skipping %s because the number of nodes in the middle is not equal to the total number of "
3986
+ "branch nodes %d != %d\n",
3987
+ root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
3988
+ continue;
3989
+ }
3990
+
3991
+ // Save the original order of nodes in this region before interleaving
3992
+ // This is used later to restore grouping for fusion within streams
3993
+ concurrent_event.original_order.reserve(total_branch_nodes);
3994
+ for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
3995
+ concurrent_event.original_order.push_back(cgraph->nodes[i]);
3996
+ }
3997
+
3998
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
3999
+ GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
4000
+ concurrent_events.emplace(root_node, std::move(concurrent_event));
4001
+ GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
4002
+ concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
4003
+
4004
+ // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
4005
+ // example transformation:
4006
+ // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
4007
+ // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
4008
+ while (current_node_idx < join_node_idx) {
4009
+ std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
4010
+
4011
+ bool has_node = false;
4012
+ for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
4013
+ has_node |= branch_node.size() > 0;
4014
+ }
4015
+
4016
+ GGML_ASSERT(has_node);
4017
+
4018
+ if (branch_nodes.empty()) {
4019
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4020
+ continue;
4021
+ }
4022
+
4023
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4024
+ current_node_idx++;
4025
+ branch_nodes.erase(branch_nodes.begin());
4026
+
4027
+ // append all empty nodes
4028
+ while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
4029
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4030
+ current_node_idx++;
4031
+ branch_nodes.erase(branch_nodes.begin());
4032
+ }
4033
+
4034
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4035
+ }
4036
+ }
4037
+ }
4038
+ }
4039
+ }
4040
+
3188
4041
  static const ggml_backend_i ggml_backend_cuda_interface = {
3189
4042
  /* .get_name = */ ggml_backend_cuda_get_name,
3190
4043
  /* .free = */ ggml_backend_cuda_free,
@@ -3199,7 +4052,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
3199
4052
  /* .graph_compute = */ ggml_backend_cuda_graph_compute,
3200
4053
  /* .event_record = */ ggml_backend_cuda_event_record,
3201
4054
  /* .event_wait = */ ggml_backend_cuda_event_wait,
3202
- /* .graph_optimize = */ NULL,
4055
+ /* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
3203
4056
  };
3204
4057
 
3205
4058
  static ggml_guid_t ggml_backend_cuda_guid() {
@@ -3270,6 +4123,7 @@ struct ggml_backend_cuda_device_context {
3270
4123
  std::string name;
3271
4124
  std::string description;
3272
4125
  std::string pci_bus_id;
4126
+ int op_offload_min_batch_size;
3273
4127
  };
3274
4128
 
3275
4129
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3282,10 +4136,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
3282
4136
  return ctx->description.c_str();
3283
4137
  }
3284
4138
 
4139
+ #if defined(__linux__)
4140
+ // Helper function to get available memory from /proc/meminfo for UMA systems
4141
+ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
4142
+ FILE * meminfo_file = nullptr;
4143
+ // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
4144
+ const size_t BUFFER_SIZE = 2048;
4145
+ auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
4146
+ size_t bytes_read = 0;
4147
+ long huge_tlb_total_pages = -1;
4148
+ long huge_tlb_free_pages = -1;
4149
+ long huge_tlb_page_size = -1;
4150
+
4151
+ if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
4152
+ return false;
4153
+ }
4154
+
4155
+ meminfo_file = fopen("/proc/meminfo", "r");
4156
+ if (meminfo_file == nullptr) {
4157
+ GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
4158
+ return false;
4159
+ }
4160
+
4161
+ // Read file into buffer
4162
+ bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
4163
+ fclose(meminfo_file);
4164
+
4165
+ if (bytes_read == 0) {
4166
+ GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
4167
+ return false;
4168
+ }
4169
+ file_buffer[bytes_read] = '\0';
4170
+
4171
+ *available_memory_kb = -1;
4172
+ *free_swap_kb = -1;
4173
+
4174
+ // Parse the file buffer line by line
4175
+ char * line = file_buffer.get();
4176
+ char * line_next;
4177
+ while (line < file_buffer.get() + bytes_read) {
4178
+ // Find the end of the current line
4179
+ line_next = strchr(line, '\n');
4180
+ if (line_next != nullptr) {
4181
+ *line_next = '\0';
4182
+ line_next++;
4183
+ } else {
4184
+ line_next = file_buffer.get() + bytes_read;
4185
+ }
4186
+
4187
+ long value;
4188
+ if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
4189
+ *available_memory_kb = value;
4190
+ } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
4191
+ *free_swap_kb = value;
4192
+ } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
4193
+ huge_tlb_total_pages = value;
4194
+ } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
4195
+ huge_tlb_free_pages = value;
4196
+ } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
4197
+ huge_tlb_page_size = value;
4198
+ }
4199
+
4200
+ line = line_next;
4201
+ }
4202
+
4203
+ if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
4204
+ *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
4205
+
4206
+ // Hugetlbfs pages are not swappable.
4207
+ *free_swap_kb = 0;
4208
+ }
4209
+
4210
+ GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
4211
+ return true;
4212
+ }
4213
+ #endif // defined(__linux__)
4214
+
3285
4215
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
3286
4216
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
3287
4217
  ggml_cuda_set_device(ctx->device);
3288
4218
  CUDA_CHECK(cudaMemGetInfo(free, total));
4219
+
4220
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17368
4221
+ #if defined(__linux__)
4222
+ // Check if this is a UMA (Unified Memory Architecture) system
4223
+ cudaDeviceProp prop;
4224
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
4225
+
4226
+ // Check if UMA is explicitly enabled via environment variable
4227
+ bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
4228
+ bool is_uma = prop.integrated > 0 || uma_env;
4229
+
4230
+ if (is_uma) {
4231
+ // For UMA systems (like DGX Spark), use system memory info
4232
+ long available_memory_kb = 0;
4233
+ long free_swap_kb = 0;
4234
+
4235
+ if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
4236
+ *free = (size_t)available_memory_kb * 1024;
4237
+ } else {
4238
+ GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
4239
+ }
4240
+ }
4241
+ #endif // defined(__linux__)
4242
+
3289
4243
  }
3290
4244
 
3291
4245
  static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -3373,7 +4327,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3373
4327
  case GGML_UNARY_OP_GELU_QUICK:
3374
4328
  case GGML_UNARY_OP_TANH:
3375
4329
  case GGML_UNARY_OP_EXP:
4330
+ case GGML_UNARY_OP_EXPM1:
4331
+ case GGML_UNARY_OP_SOFTPLUS:
3376
4332
  case GGML_UNARY_OP_ELU:
4333
+ case GGML_UNARY_OP_XIELU:
4334
+ case GGML_UNARY_OP_FLOOR:
4335
+ case GGML_UNARY_OP_CEIL:
4336
+ case GGML_UNARY_OP_ROUND:
4337
+ case GGML_UNARY_OP_TRUNC:
3377
4338
  return ggml_is_contiguous(op->src[0]);
3378
4339
  default:
3379
4340
  return false;
@@ -3488,6 +4449,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3488
4449
  op->src[0]->type == GGML_TYPE_F32 &&
3489
4450
  (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
3490
4451
  } break;
4452
+ case GGML_OP_SET:
4453
+ {
4454
+ const ggml_type t = op->type;
4455
+ return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
4456
+ t == op->src[0]->type &&
4457
+ t == op->src[1]->type;
4458
+ } break;
3491
4459
  case GGML_OP_CPY:
3492
4460
  {
3493
4461
  ggml_type src0_type = op->src[0]->type;
@@ -3536,6 +4504,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3536
4504
  if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
3537
4505
  return true;
3538
4506
  }
4507
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
4508
+ return true;
4509
+ }
3539
4510
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
3540
4511
  return true;
3541
4512
  }
@@ -3642,12 +4613,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3642
4613
  case GGML_OP_CONV_2D_DW:
3643
4614
  case GGML_OP_CONV_TRANSPOSE_2D:
3644
4615
  case GGML_OP_POOL_2D:
3645
- case GGML_OP_SUM:
3646
4616
  case GGML_OP_ACC:
3647
4617
  return true;
4618
+ case GGML_OP_SUM:
4619
+ return ggml_is_contiguous_rows(op->src[0]);
4620
+ case GGML_OP_TOP_K:
3648
4621
  case GGML_OP_ARGSORT:
3649
- // TODO: Support arbitrary column width
4622
+ #ifndef GGML_CUDA_USE_CUB
3650
4623
  return op->src[0]->ne[0] <= 1024;
4624
+ #else
4625
+ return true;
4626
+ #endif
3651
4627
  case GGML_OP_SUM_ROWS:
3652
4628
  case GGML_OP_MEAN:
3653
4629
  case GGML_OP_GROUP_NORM:
@@ -3668,7 +4644,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3668
4644
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
3669
4645
  case GGML_OP_OPT_STEP_ADAMW:
3670
4646
  case GGML_OP_OPT_STEP_SGD:
4647
+ case GGML_OP_FILL:
4648
+ case GGML_OP_CUMSUM:
4649
+ case GGML_OP_TRI:
4650
+ case GGML_OP_DIAG:
4651
+ case GGML_OP_SOLVE_TRI:
3671
4652
  return true;
4653
+
3672
4654
  default:
3673
4655
  return false;
3674
4656
  }
@@ -3696,11 +4678,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
3696
4678
  }
3697
4679
 
3698
4680
  static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
3699
- const int min_batch_size = 32;
3700
-
3701
- return get_op_batch_size(op) >= min_batch_size;
4681
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
3702
4682
 
3703
- GGML_UNUSED(dev);
4683
+ return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
3704
4684
  }
3705
4685
 
3706
4686
  static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@@ -3811,6 +4791,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
3811
4791
  features.push_back({ "FA_ALL_QUANTS", "1" });
3812
4792
  #endif
3813
4793
 
4794
+ {
4795
+ const auto & info = ggml_cuda_info();
4796
+ for (int id = 0; id < info.device_count; ++id) {
4797
+ if (blackwell_mma_available(info.devices[id].cc)) {
4798
+ features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
4799
+ break;
4800
+ }
4801
+ }
4802
+ }
4803
+
3814
4804
  #undef _STRINGIFY
3815
4805
  #undef STRINGIFY
3816
4806
 
@@ -3858,13 +4848,13 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
3858
4848
  std::lock_guard<std::mutex> lock(mutex);
3859
4849
  if (!initialized) {
3860
4850
  ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
4851
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
3861
4852
 
3862
4853
  for (int i = 0; i < ggml_cuda_info().device_count; i++) {
3863
4854
  ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
3864
4855
  dev_ctx->device = i;
3865
4856
  dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
3866
4857
 
3867
- ggml_cuda_set_device(i);
3868
4858
  cudaDeviceProp prop;
3869
4859
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
3870
4860
  dev_ctx->description = prop.name;
@@ -3872,6 +4862,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
3872
4862
  char pci_bus_id[16] = {};
3873
4863
  snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
3874
4864
  dev_ctx->pci_bus_id = pci_bus_id;
4865
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
3875
4866
 
3876
4867
  ggml_backend_dev_t dev = new ggml_backend_device {
3877
4868
  /* .iface = */ ggml_backend_cuda_device_interface,