whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891):
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -19,7 +19,9 @@
19
19
  #include "ggml-cuda/count-equal.cuh"
20
20
  #include "ggml-cuda/cpy.cuh"
21
21
  #include "ggml-cuda/cross-entropy-loss.cuh"
22
+ #include "ggml-cuda/cumsum.cuh"
22
23
  #include "ggml-cuda/diagmask.cuh"
24
+ #include "ggml-cuda/diag.cuh"
23
25
  #include "ggml-cuda/fattn.cuh"
24
26
  #include "ggml-cuda/getrows.cuh"
25
27
  #include "ggml-cuda/im2col.cuh"
@@ -43,6 +45,7 @@
43
45
  #include "ggml-cuda/ssm-scan.cuh"
44
46
  #include "ggml-cuda/sum.cuh"
45
47
  #include "ggml-cuda/sumrows.cuh"
48
+ #include "ggml-cuda/top-k.cuh"
46
49
  #include "ggml-cuda/mean.cuh"
47
50
  #include "ggml-cuda/tsembd.cuh"
48
51
  #include "ggml-cuda/topk-moe.cuh"
@@ -50,8 +53,14 @@
50
53
  #include "ggml-cuda/upscale.cuh"
51
54
  #include "ggml-cuda/wkv.cuh"
52
55
  #include "ggml-cuda/gla.cuh"
56
+ #include "ggml-cuda/gated_delta_net.cuh"
57
+ #include "ggml-cuda/set.cuh"
53
58
  #include "ggml-cuda/set-rows.cuh"
54
59
  #include "ggml-cuda/pad_reflect_1d.cuh"
60
+ #include "ggml-cuda/solve_tri.cuh"
61
+ #include "ggml-cuda/tri.cuh"
62
+ #include "ggml-cuda/cumsum.cuh"
63
+ #include "ggml-cuda/fill.cuh"
55
64
  #include "ggml.h"
56
65
 
57
66
  #include <algorithm>
@@ -62,17 +71,18 @@
62
71
  #include <condition_variable>
63
72
  #include <cstddef>
64
73
  #include <cstdint>
65
- #include <float.h>
74
+ #include <cfloat>
66
75
  #include <initializer_list>
67
76
  #include <limits>
68
77
  #include <map>
69
78
  #include <memory>
70
79
  #include <mutex>
71
- #include <stdarg.h>
72
- #include <stdio.h>
73
- #include <stdlib.h>
80
+ #include <cstdarg>
81
+ #include <cstdio>
82
+ #include <cstdlib>
74
83
  #include <string>
75
84
  #include <vector>
85
+ #include <unordered_set>
76
86
 
77
87
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
78
88
 
@@ -114,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
114
124
  err = cudaMallocManaged(ptr, size);
115
125
  #if defined(GGML_USE_HIP)
116
126
  if (err == hipSuccess) {
117
- CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
127
+ // hipMemAdviseSetCoarseGrain is an optional performance hint;
128
+ // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
129
+ cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
130
+ (void)hipGetLastError(); // clear any error
118
131
  }
119
132
 
120
133
  // fall back to cudaMalloc if not supported (e.g. on Windows)
@@ -195,17 +208,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
195
208
  GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
196
209
 
197
210
  int64_t total_vram = 0;
198
- #ifdef GGML_CUDA_FORCE_MMQ
199
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
200
- #else
201
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
202
- #endif // GGML_CUDA_FORCE_MMQ
203
- #ifdef GGML_CUDA_FORCE_CUBLAS
204
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
205
- #else
206
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
207
- #endif // GGML_CUDA_FORCE_CUBLAS
208
- GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
211
+ for (int id = 0; id < info.device_count; ++id) {
212
+ cudaDeviceProp prop;
213
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
214
+ total_vram += prop.totalGlobalMem;
215
+ }
216
+ GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
217
+ __func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
218
+ total_vram = 0;
209
219
 
210
220
  std::vector<std::pair<int, std::string>> turing_devices_without_mma;
211
221
  for (int id = 0; id < info.device_count; ++id) {
@@ -231,10 +241,19 @@ static ggml_cuda_device_info ggml_cuda_init() {
231
241
 
232
242
  info.default_tensor_split[id] = total_vram;
233
243
  total_vram += prop.totalGlobalMem;
234
- info.devices[id].integrated = prop.integrated;
244
+ info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
235
245
  info.devices[id].nsm = prop.multiProcessorCount;
236
246
  info.devices[id].smpb = prop.sharedMemPerBlock;
237
247
  info.devices[id].warp_size = prop.warpSize;
248
+
249
+ #ifndef GGML_USE_MUSA
250
+ int supports_coop_launch = 0;
251
+ CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
252
+ info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
253
+ #else
254
+ info.devices[id].supports_cooperative_launch = false;
255
+ #endif // !(GGML_USE_MUSA)
256
+
238
257
  #if defined(GGML_USE_HIP)
239
258
  info.devices[id].smpbo = prop.sharedMemPerBlock;
240
259
 
@@ -249,22 +268,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
249
268
  info.devices[id].cc += prop.minor * 0x10;
250
269
  }
251
270
  }
252
- GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
271
+ GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n",
253
272
  id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
254
- device_vmm ? "yes" : "no", prop.warpSize);
273
+ device_vmm ? "yes" : "no", prop.warpSize,
274
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
255
275
  #elif defined(GGML_USE_MUSA)
256
276
  // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
257
277
  info.devices[id].warp_size = 32;
258
278
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
259
279
  info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
260
280
  info.devices[id].cc += prop.minor * 0x10;
261
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
262
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
281
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
282
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
283
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
263
284
  #else
264
285
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
265
286
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
266
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
267
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
287
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
288
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
289
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
268
290
  std::string device_name(prop.name);
269
291
  if (device_name == "NVIDIA GeForce MX450") {
270
292
  turing_devices_without_mma.push_back({ id, device_name });
@@ -273,6 +295,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
273
295
  } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
274
296
  turing_devices_without_mma.push_back({ id, device_name });
275
297
  }
298
+
299
+ // Temporary performance fix:
300
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
301
+ // TODO: Check for future drivers the default scheduling strategy and
302
+ // remove this call again when cudaDeviceScheduleSpin is default.
303
+ if (prop.major == 12 && prop.minor == 1) {
304
+ CUDA_CHECK(cudaSetDevice(id));
305
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
306
+ }
307
+
276
308
  #endif // defined(GGML_USE_HIP)
277
309
  }
278
310
 
@@ -511,7 +543,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
511
543
  };
512
544
  #endif // defined(GGML_USE_VMM)
513
545
 
514
- std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
546
+ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
547
+ [[maybe_unused]] int stream_no) {
515
548
  #if defined(GGML_USE_VMM)
516
549
  if (ggml_cuda_info().devices[device].vmm) {
517
550
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@@ -1208,6 +1241,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
1208
1241
  }
1209
1242
  }
1210
1243
 
1244
+ struct cublas_force_compute_type {
1245
+ bool fp32 = false;
1246
+ bool fp16 = false;
1247
+ };
1248
+
1249
+ static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
1250
+ static const cublas_force_compute_type compute_type = [] {
1251
+ cublas_force_compute_type result;
1252
+
1253
+ const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
1254
+ const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
1255
+
1256
+ GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);
1257
+
1258
+ if (ggml_cuda_force_cublas_compute_32f_env) {
1259
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
1260
+ result.fp32 = true;
1261
+ } else if (ggml_cuda_force_cublas_compute_16f_env) {
1262
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
1263
+ result.fp16 = true;
1264
+ }
1265
+
1266
+ return result;
1267
+ }();
1268
+
1269
+ return compute_type;
1270
+ }
1271
+
1211
1272
  static void ggml_cuda_op_mul_mat_cublas(
1212
1273
  ggml_backend_cuda_context & ctx,
1213
1274
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -1290,7 +1351,13 @@ static void ggml_cuda_op_mul_mat_cublas(
1290
1351
 
1291
1352
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1292
1353
 
1293
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1354
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
1355
+
1356
+ if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
1357
+ || GGML_CUDA_CC_IS_RDNA4(cc)
1358
+ || cc == GGML_CUDA_CC_VOLTA
1359
+ || force_compute_type.fp32))
1360
+ {
1294
1361
  const float alpha = 1.0f;
1295
1362
  const float beta = 0.0f;
1296
1363
  CUBLAS_CHECK(
@@ -1889,10 +1956,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1889
1956
  cudaDataType_t cu_data_type_b = traits::data_type;
1890
1957
  const void * alpha = traits::get_alpha();
1891
1958
  const void * beta = traits::get_beta();
1892
- const float alpha_f32 = 1.0f;
1893
- const float beta_f32 = 0.0f;
1894
1959
 
1895
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1960
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
1961
+
1962
+ int id = ggml_cuda_get_device();
1963
+ const int cc = ggml_cuda_info().devices[id].cc;
1964
+ static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
1965
+
1966
+ // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),
1967
+ // so checking necessity of forced fp32 only for fp16 src0_type
1968
+ static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
1969
+
1970
+ const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
1971
+ || GGML_CUDA_CC_IS_RDNA4(cc)
1972
+ || cc == GGML_CUDA_CC_VOLTA
1973
+ || force_compute_type.fp32);
1974
+
1975
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
1896
1976
  if constexpr (src0_type == GGML_TYPE_F32) {
1897
1977
  dst_t = (char *) dst_ddf; // Direct F32 output
1898
1978
  } else {
@@ -1902,18 +1982,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1902
1982
  }
1903
1983
  } else {
1904
1984
  dst_t = (char *) dst_ddf;
1905
- cu_compute_type = CUBLAS_COMPUTE_32F;
1906
- cu_data_type = CUDA_R_32F;
1907
- alpha = &alpha_f32;
1908
- beta = &beta_f32;
1909
- }
1910
-
1911
- int id = ggml_cuda_get_device();
1912
- const int cc = ggml_cuda_info().devices[id].cc;
1913
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1914
- cu_compute_type = CUBLAS_COMPUTE_32F;
1915
- alpha = &alpha_f32;
1916
- beta = &beta_f32;
1985
+ cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
1986
+ cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
1987
+ alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
1988
+ beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
1917
1989
  }
1918
1990
 
1919
1991
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1948,8 +2020,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1948
2020
 
1949
2021
  size_t src1_stride_size = sizeof(cuda_t);
1950
2022
 
1951
- dim3 block_dims(ne13, ne12);
1952
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
2023
+ const int threads_x = 16;
2024
+ const int threads_y = 16;
2025
+ dim3 block_dims(threads_x, threads_y);
2026
+
2027
+ dim3 grid_dims(
2028
+ (ne13 + threads_x - 1) / threads_x,
2029
+ (ne12 + threads_y - 1) / threads_y
2030
+ );
2031
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
1953
2032
  src0_ptr, src1_ptr, dst_t,
1954
2033
  ptrs_src.get(), ptrs_dst.get(),
1955
2034
  ne12, ne13,
@@ -1998,6 +2077,164 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1998
2077
  }
1999
2078
  }
2000
2079
 
2080
+ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
2081
+ const ggml_tensor * ffn_gate,
2082
+ const ggml_tensor * glu,
2083
+ const ggml_tensor * ffn_up_bias = nullptr,
2084
+ const ggml_tensor * ffn_gate_bias = nullptr) {
2085
+ const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
2086
+
2087
+ if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
2088
+ return false;
2089
+ }
2090
+
2091
+ const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
2092
+ const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
2093
+
2094
+ GGML_ASSERT(ffn_up && ffn_gate && glu);
2095
+
2096
+ if (!is_mul_mat && !is_mul_mat_id) {
2097
+ return false;
2098
+ }
2099
+
2100
+ const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
2101
+
2102
+ if (has_bias) {
2103
+ if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
2104
+ return false;
2105
+ }
2106
+
2107
+ if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
2108
+ return false;
2109
+ }
2110
+
2111
+ if (expected_bias_op == GGML_OP_ADD) {
2112
+ const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
2113
+ const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
2114
+ if (!up_has_mul || !gate_has_mul) {
2115
+ return false;
2116
+ }
2117
+ } else { // GGML_OP_ADD_ID
2118
+ if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
2119
+ return false;
2120
+ }
2121
+ if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
2122
+ return false;
2123
+ }
2124
+ }
2125
+ } else {
2126
+ if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
2127
+ return false;
2128
+ }
2129
+ }
2130
+
2131
+ if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
2132
+ !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
2133
+ return false;
2134
+ }
2135
+
2136
+ if (ffn_up->src[1] != ffn_gate->src[1]) {
2137
+ return false;
2138
+ }
2139
+
2140
+ if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
2141
+ return false;
2142
+ }
2143
+
2144
+ static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
2145
+
2146
+ if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
2147
+ return false;
2148
+ }
2149
+
2150
+ if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
2151
+ return false;
2152
+ }
2153
+
2154
+ const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
2155
+ ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
2156
+
2157
+ //TODO: add support for fusion for split buffers
2158
+ if (split) {
2159
+ return false;
2160
+ }
2161
+
2162
+ return true;
2163
+ }
2164
+
2165
+ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
2166
+ ggml_tensor * src0 = tensor->src[0];
2167
+ ggml_tensor * src1 = tensor->src[1];
2168
+ const ggml_tensor * dst = tensor;
2169
+
2170
+ const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
2171
+
2172
+ bool use_mul_mat_vec_f =
2173
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
2174
+ src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2175
+
2176
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2177
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
2178
+
2179
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2180
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2181
+
2182
+ //TODO: add support for fusion for split buffers
2183
+ if (split) {
2184
+ return false;
2185
+ }
2186
+
2187
+ //we only support fusion for ncols_dst = 1
2188
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2189
+ return false;
2190
+ }
2191
+
2192
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2193
+ return false;
2194
+ }
2195
+
2196
+
2197
+ return use_mul_mat_vec_f;
2198
+ }
2199
+
2200
+ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
2201
+ ggml_tensor * src0 = tensor->src[0];
2202
+ ggml_tensor * src1 = tensor->src[1];
2203
+ const ggml_tensor * dst = tensor;
2204
+
2205
+ const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
2206
+ ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
2207
+ src0->view_src;
2208
+
2209
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
2210
+ dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
2211
+
2212
+ // fusion is not universally faster on Pascal
2213
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2214
+ if (cc <= GGML_CUDA_CC_PASCAL) {
2215
+ return false;
2216
+ }
2217
+ //we only support fusion for ncols_dst = 1
2218
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2219
+ return false;
2220
+ }
2221
+
2222
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2223
+ return false;
2224
+ }
2225
+
2226
+
2227
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2228
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2229
+
2230
+ //TODO: add support for fusion for split buffers
2231
+ if (split) {
2232
+ return false;
2233
+ }
2234
+
2235
+ return use_mul_mat_vec_q;
2236
+ }
2237
+
2001
2238
  static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2002
2239
  const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
2003
2240
 
@@ -2030,17 +2267,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2030
2267
 
2031
2268
  const int cc = ggml_cuda_info().devices[id].cc;
2032
2269
  const int warp_size = ggml_cuda_info().devices[id].warp_size;
2033
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2034
- use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
2035
- use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
2270
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2271
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2272
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2036
2273
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2037
2274
  }
2038
2275
  } else {
2039
2276
  const int cc = ggml_cuda_info().devices[ctx.device].cc;
2040
2277
  const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
2041
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2042
- use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
2043
- use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
2278
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2279
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2280
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2044
2281
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2045
2282
  }
2046
2283
 
@@ -2096,27 +2333,36 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2096
2333
 
2097
2334
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2098
2335
 
2336
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2099
2337
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2100
- if (ne2 == 1) {
2338
+ static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
2339
+ if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
2101
2340
  if (ggml_is_quantized(src0->type)) {
2102
- ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2341
+ if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
2342
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2343
+ return;
2344
+ }
2103
2345
  } else {
2104
- ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2346
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
2347
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2348
+ return;
2349
+ }
2105
2350
  }
2106
- return;
2107
2351
  }
2108
2352
 
2109
- if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
2353
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
2110
2354
  ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
2111
2355
  return;
2112
2356
  }
2113
2357
 
2114
- if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
2358
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
2115
2359
  ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
2116
2360
  return;
2117
2361
  }
2118
2362
  }
2119
2363
 
2364
+ // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
2365
+ // TODO: add asserts to verify this. should work with CUDA, HIP, etc.
2120
2366
  cudaStream_t stream = ctx.stream();
2121
2367
 
2122
2368
  GGML_ASSERT(nb12 % nb11 == 0);
@@ -2259,6 +2505,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2259
2505
  case GGML_OP_SET_ROWS:
2260
2506
  ggml_cuda_op_set_rows(ctx, dst);
2261
2507
  break;
2508
+ case GGML_OP_SET:
2509
+ ggml_cuda_op_set(ctx, dst);
2510
+ break;
2262
2511
  case GGML_OP_DUP:
2263
2512
  ggml_cuda_dup(ctx, dst);
2264
2513
  break;
@@ -2334,6 +2583,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2334
2583
  case GGML_UNARY_OP_ELU:
2335
2584
  ggml_cuda_op_elu(ctx, dst);
2336
2585
  break;
2586
+ case GGML_UNARY_OP_XIELU:
2587
+ ggml_cuda_op_xielu(ctx, dst);
2588
+ break;
2589
+ case GGML_UNARY_OP_FLOOR:
2590
+ ggml_cuda_op_floor(ctx, dst);
2591
+ break;
2592
+ case GGML_UNARY_OP_CEIL:
2593
+ ggml_cuda_op_ceil(ctx, dst);
2594
+ break;
2595
+ case GGML_UNARY_OP_ROUND:
2596
+ ggml_cuda_op_round(ctx, dst);
2597
+ break;
2598
+ case GGML_UNARY_OP_TRUNC:
2599
+ ggml_cuda_op_trunc(ctx, dst);
2600
+ break;
2601
+ case GGML_UNARY_OP_EXPM1:
2602
+ ggml_cuda_op_expm1(ctx, dst);
2603
+ break;
2604
+ case GGML_UNARY_OP_SOFTPLUS:
2605
+ ggml_cuda_op_softplus(ctx, dst);
2606
+ break;
2337
2607
  default:
2338
2608
  return false;
2339
2609
  }
@@ -2437,6 +2707,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2437
2707
  case GGML_OP_PERMUTE:
2438
2708
  case GGML_OP_TRANSPOSE:
2439
2709
  break;
2710
+ case GGML_OP_DIAG:
2711
+ ggml_cuda_op_diag(ctx, dst);
2712
+ break;
2440
2713
  case GGML_OP_DIAG_MASK_INF:
2441
2714
  ggml_cuda_op_diag_mask_inf(ctx, dst);
2442
2715
  break;
@@ -2479,6 +2752,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2479
2752
  case GGML_OP_SUM:
2480
2753
  ggml_cuda_op_sum(ctx, dst);
2481
2754
  break;
2755
+ case GGML_OP_CUMSUM:
2756
+ ggml_cuda_op_cumsum(ctx, dst);
2757
+ break;
2482
2758
  case GGML_OP_SUM_ROWS:
2483
2759
  ggml_cuda_op_sum_rows(ctx, dst);
2484
2760
  break;
@@ -2491,6 +2767,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2491
2767
  case GGML_OP_SSM_SCAN:
2492
2768
  ggml_cuda_op_ssm_scan(ctx, dst);
2493
2769
  break;
2770
+ case GGML_OP_TOP_K:
2771
+ ggml_cuda_op_top_k(ctx, dst);
2772
+ break;
2494
2773
  case GGML_OP_ARGSORT:
2495
2774
  ggml_cuda_op_argsort(ctx, dst);
2496
2775
  break;
@@ -2500,12 +2779,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2500
2779
  case GGML_OP_CROSS_ENTROPY_LOSS:
2501
2780
  ggml_cuda_cross_entropy_loss(ctx, dst);
2502
2781
  break;
2782
+ case GGML_OP_TRI:
2783
+ ggml_cuda_op_tri(ctx, dst);
2784
+ break;
2503
2785
  case GGML_OP_RWKV_WKV6:
2504
2786
  ggml_cuda_op_rwkv_wkv6(ctx, dst);
2505
2787
  break;
2506
2788
  case GGML_OP_GATED_LINEAR_ATTN:
2507
2789
  ggml_cuda_op_gated_linear_attn(ctx, dst);
2508
2790
  break;
2791
+ case GGML_OP_GATED_DELTA_NET:
2792
+ ggml_cuda_op_gated_delta_net(ctx, dst);
2793
+ break;
2509
2794
  case GGML_OP_RWKV_WKV7:
2510
2795
  ggml_cuda_op_rwkv_wkv7(ctx, dst);
2511
2796
  break;
@@ -2518,6 +2803,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2518
2803
  case GGML_OP_OPT_STEP_SGD:
2519
2804
  ggml_cuda_opt_step_sgd(ctx, dst);
2520
2805
  break;
2806
+ case GGML_OP_SOLVE_TRI:
2807
+ ggml_cuda_op_solve_tri(ctx, dst);
2808
+ break;
2809
+ case GGML_OP_FILL:
2810
+ ggml_cuda_op_fill(ctx, dst);
2811
+ break;
2521
2812
  default:
2522
2813
  return false;
2523
2814
  }
@@ -2630,19 +2921,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
2630
2921
  }
2631
2922
 
2632
2923
  #ifdef USE_CUDA_GRAPH
2633
- static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2634
- bool use_cuda_graph) {
2924
+ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2635
2925
 
2926
+ bool use_cuda_graph = true;
2636
2927
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2637
- cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
2638
-
2639
- const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2640
- const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2641
- const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2642
- const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2643
- const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2644
- const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2645
- const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2646
2928
 
2647
2929
  for (int i = 0; i < cgraph->n_nodes; i++) {
2648
2930
  ggml_tensor * node = cgraph->nodes[i];
@@ -2658,47 +2940,15 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2658
2940
  #endif
2659
2941
  }
2660
2942
 
2661
- if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2662
- use_cuda_graph = false; // This node type is not supported by CUDA graph capture
2663
- #ifndef NDEBUG
2664
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2665
- #endif
2666
- }
2667
-
2668
- if (node->op == GGML_OP_ADD &&
2669
- node->src[1] && node->src[1]->ne[1] > 1 &&
2670
- (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2671
- (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2672
- strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2673
- strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2674
- strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2675
- strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2676
- strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
2677
- // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2678
- // by means of matching node names. See
2679
- // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2680
- // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2681
- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2943
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2944
+ if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
2945
+ // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2946
+ // TODO: figure out a way to enable for larger batch sizes, without hurting performance
2947
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18958
2682
2948
  use_cuda_graph = false;
2683
2949
  #ifndef NDEBUG
2684
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2685
- #endif
2686
- }
2687
-
2688
- if (node->op == GGML_OP_CPY) {
2689
-
2690
- // Store the pointers which are updated for each token, such that these can be sent
2691
- // to the device and accessed using indirection from CUDA graph
2692
- cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
2693
-
2694
- // store a pointer to each copy op CUDA kernel to identify it later
2695
- void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2696
- if (!ptr) {
2697
- use_cuda_graph = false;
2698
- #ifndef NDEBUG
2699
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
2950
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2700
2951
  #endif
2701
- }
2702
2952
  }
2703
2953
 
2704
2954
  if (!use_cuda_graph) {
@@ -2706,105 +2956,149 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2706
2956
  }
2707
2957
  }
2708
2958
 
2709
- if (use_cuda_graph) {
2710
- cuda_ctx->cuda_graph->use_cpy_indirection = true;
2711
- // copy pointers to GPU so they can be accessed via indirection within CUDA graph
2712
- ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
2713
- }
2714
-
2715
2959
  return use_cuda_graph;
2716
2960
  }
2717
2961
 
2718
- static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2719
- graph_node_properties->node_address = node->data;
2720
- graph_node_properties->node_op = node->op;
2962
+ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2963
+ memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
2964
+ props->node_data = node->data;
2965
+ props->node_op = node->op;
2966
+ props->node_type = node->type;
2967
+ props->flags = node->flags;
2721
2968
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2722
- graph_node_properties->ne[i] = node->ne[i];
2723
- graph_node_properties->nb[i] = node->nb[i];
2969
+ props->ne[i] = node->ne[i];
2970
+ props->nb[i] = node->nb[i];
2724
2971
  }
2725
2972
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2726
- graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2973
+ if (!node->src[i]) {
2974
+ continue;
2975
+ }
2976
+
2977
+ props->src_data[i] = node->src[i]->data;
2727
2978
  }
2728
- memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2979
+ memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2729
2980
  }
2730
2981
 
2731
- static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2732
- if (node->data != graph_node_properties->node_address &&
2733
- node->op != GGML_OP_CPY &&
2734
- node->op != GGML_OP_VIEW) {
2982
+ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2983
+ if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
2984
+ return false;
2985
+ }
2986
+
2987
+ if (node->op != props->node_op) {
2735
2988
  return false;
2736
2989
  }
2737
2990
 
2738
- if (node->op != graph_node_properties->node_op) {
2991
+ if (node->type != props->node_type) {
2739
2992
  return false;
2740
2993
  }
2741
2994
 
2742
2995
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2743
- if (node->ne[i] != graph_node_properties->ne[i]) {
2996
+ if (node->ne[i] != props->ne[i]) {
2744
2997
  return false;
2745
2998
  }
2746
- if (node->nb[i] != graph_node_properties->nb[i]) {
2999
+ if (node->nb[i] != props->nb[i]) {
2747
3000
  return false;
2748
3001
  }
2749
3002
  }
2750
3003
 
2751
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2752
- if (node->src[i] &&
2753
- node->src[i]->data != graph_node_properties->src_address[i] &&
2754
- node->op != GGML_OP_CPY &&
2755
- node->op != GGML_OP_VIEW
2756
- ) {
2757
- return false;
3004
+ if (node->op != GGML_OP_VIEW) {
3005
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
3006
+ if (!node->src[i]) {
3007
+ if (props->src_data[i] != nullptr) {
3008
+ return false;
3009
+ }
3010
+ continue;
3011
+ }
3012
+
3013
+ if (node->src[i]->data != props->src_data[i]) {
3014
+ return false;
3015
+ }
2758
3016
  }
2759
3017
  }
2760
3018
 
2761
- if (node->op == GGML_OP_SCALE &&
2762
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3019
+ if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3020
+ return false;
3021
+ }
3022
+
3023
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
2763
3024
  return false;
2764
3025
  }
2765
3026
 
2766
3027
  return true;
2767
3028
  }
2768
3029
 
2769
- static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3030
+ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
3031
+ return cgraph->nodes[0];
3032
+ }
2770
3033
 
2771
- bool cuda_graph_update_required = false;
3034
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3035
+ bool res = false;
2772
3036
 
2773
- if (cuda_ctx->cuda_graph->instance == nullptr) {
2774
- cuda_graph_update_required = true;
2775
- }
3037
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3038
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
2776
3039
 
2777
3040
  // Check if the graph size has changed
2778
- if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2779
- cuda_graph_update_required = true;
2780
- cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
3041
+ if (graph->props.size() != (size_t)cgraph->n_nodes) {
3042
+ res = true;
3043
+ graph->props.resize(cgraph->n_nodes);
2781
3044
  }
2782
3045
 
2783
3046
  // Loop over nodes in GGML graph to determine if CUDA graph update is required
2784
3047
  // and store properties to allow this comparison for the next token
3048
+ std::unordered_set<ggml_tensor *> seen_node;
3049
+ std::vector<ggml_tensor *> srcs_extra;
2785
3050
  for (int i = 0; i < cgraph->n_nodes; i++) {
2786
- bool has_matching_properties = true;
2787
- if (!cuda_graph_update_required) {
2788
- has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
3051
+ bool props_match = true;
3052
+
3053
+ seen_node.insert(cgraph->nodes[i]);
3054
+
3055
+ if (!res) {
3056
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
3057
+ }
3058
+ if (!props_match) {
3059
+ res = true;
3060
+ }
3061
+ ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
3062
+
3063
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3064
+ ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
3065
+ if (src && seen_node.find(src) == seen_node.end()) {
3066
+ srcs_extra.push_back(src);
3067
+ }
3068
+ }
3069
+ }
3070
+
3071
+ if (graph->extra.size() != (size_t) srcs_extra.size()) {
3072
+ res = true;
3073
+ graph->extra.resize(srcs_extra.size());
3074
+ }
3075
+
3076
+ for (size_t i = 0; i < srcs_extra.size(); ++i) {
3077
+ bool props_match = true;
3078
+
3079
+ if (!res) {
3080
+ props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
2789
3081
  }
2790
- if (!has_matching_properties) {
2791
- cuda_graph_update_required = true;
3082
+
3083
+ if (!props_match) {
3084
+ res = true;
2792
3085
  }
2793
- set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
3086
+ ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
2794
3087
  }
2795
3088
 
2796
- return cuda_graph_update_required;
3089
+ return res;
2797
3090
  }
2798
3091
 
2799
- static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
3092
+ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3093
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
2800
3094
 
2801
3095
  #if CUDART_VERSION >= 12000
2802
3096
  cudaGraphExecUpdateResultInfo result_info;
2803
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
3097
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
2804
3098
  #else
2805
3099
  cudaGraphNode_t errorNode;
2806
3100
  cudaGraphExecUpdateResult result_info;
2807
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
3101
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
2808
3102
  #endif // CUDART_VERSION >= 12000
2809
3103
 
2810
3104
  if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -2815,104 +3109,336 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2815
3109
  // The pre-existing graph exec cannot be updated due to violated constraints
2816
3110
  // so instead clear error and re-instantiate
2817
3111
  (void)cudaGetLastError();
2818
- CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
2819
- cuda_ctx->cuda_graph->instance = nullptr;
2820
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
3112
+ CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
3113
+ graph->instance = nullptr;
3114
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
2821
3115
  } else {
2822
3116
  GGML_ASSERT(stat == cudaSuccess);
2823
3117
  }
2824
3118
  }
2825
- #endif
3119
+ #endif // USE_CUDA_GRAPH
2826
3120
 
2827
- static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
2828
- #ifndef NDEBUG
2829
- const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
2830
- GGML_ASSERT(unary_ops.size() == num_unary);
2831
- #endif
3121
+ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3122
+ const ggml_tensor * view,
3123
+ const ggml_tensor * set_rows) {
2832
3124
 
2833
- //TODO: remove special case once ggml_can_fuse can handle empty nodes
2834
- std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
2835
- std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
3125
+ if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
3126
+ return false;
3127
+ }
3128
+ // ne3 not tested
3129
+ if (rope->src[0]->ne[3] != 1) {
3130
+ return false;
3131
+ }
2836
3132
 
2837
- if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
3133
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
3134
+ return false;
3135
+ }
2838
3136
 
2839
- if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
2840
- return false;
2841
- }
3137
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
3138
+ return false;
3139
+ }
2842
3140
 
2843
- for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
2844
- if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
2845
- }
2846
- ggml_tensor * softmax = cgraph->nodes[node_idx];
2847
- ggml_tensor * weights = cgraph->nodes[node_idx+8];
3141
+ // The view should flatten two dims of rope into one dim
3142
+ if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
3143
+ return false;
3144
+ }
2848
3145
 
2849
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
2850
- return true;
2851
- }
3146
+ // Only norm/neox shaders have the fusion code
3147
+ const int mode = ((const int32_t *) rope->op_params)[2];
3148
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
3149
+ return false;
2852
3150
  }
2853
3151
 
2854
- if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
3152
+ return true;
3153
+ }
2855
3154
 
2856
- if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
2857
- return false;
2858
- }
3155
+ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
3156
+ args.sigmoid = false;
3157
+ args.softmax = false;
3158
+ args.delayed_softmax = false;
3159
+ args.prob_bias = false;
3160
+ args.norm = false;
2859
3161
 
2860
- for (size_t i = 0; i < topk_moe_ops.size(); i++) {
2861
- if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
2862
- }
3162
+ const int n_nodes = cgraph->n_nodes;
3163
+ ggml_tensor ** nodes = cgraph->nodes;
2863
3164
 
2864
- ggml_tensor * softmax = cgraph->nodes[node_idx];
2865
- ggml_tensor * weights = cgraph->nodes[node_idx+4];
2866
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
2867
- return true;
3165
+ if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
3166
+ args.softmax = true;
3167
+ }
3168
+
3169
+ if (nodes[node_idx]->op == GGML_OP_UNARY) {
3170
+ if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
3171
+ return false;
2868
3172
  }
3173
+ args.sigmoid = true;
2869
3174
  }
2870
3175
 
2871
- if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2872
- return false;
3176
+ if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
3177
+ args.delayed_softmax = true;
2873
3178
  }
2874
3179
 
2875
- if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
2876
- const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
2877
- const ggml_tensor *mul = cgraph->nodes[node_idx+1];
2878
- const ggml_tensor *add = nullptr;
3180
+ node_idx++;
2879
3181
 
2880
- if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
2881
- add = cgraph->nodes[node_idx+2];
3182
+ if (args.sigmoid || args.softmax) {
3183
+ // SOFTMAX -> RESHAPE
3184
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
3185
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3186
+ return false;
2882
3187
  }
3188
+ ggml_tensor * probs_reshaped = nodes[node_idx];
3189
+ node_idx++;
2883
3190
 
2884
- GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
2885
- GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
2886
-
2887
- //rms norm only supports F32
2888
- if (mul->src[0]->type != GGML_TYPE_F32 ||
2889
- mul->src[1]->type != GGML_TYPE_F32 ||
2890
- mul->type != GGML_TYPE_F32) {
3191
+ if (node_idx >= n_nodes) {
2891
3192
  return false;
2892
3193
  }
2893
3194
 
2894
- if (add && (add->src[0]->type != GGML_TYPE_F32 ||
2895
- add->src[1]->type != GGML_TYPE_F32 ||
2896
- add->type != GGML_TYPE_F32) ) {
3195
+ // src of bias add is the unreshaped probs (-2 instead of -1)
3196
+ if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
3197
+ args.prob_bias = true;
3198
+ node_idx++;
3199
+ }
3200
+ // RESHAPE/ADD -> ARGSORT
3201
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
2897
3202
  return false;
2898
3203
  }
2899
3204
 
2900
- //if rms norm is the B operand, then we don't handle broadcast
2901
- if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
3205
+ if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3206
+ return false;
3207
+ } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
2902
3208
  return false;
2903
3209
  }
2904
3210
 
2905
- //rms_norm kernel assumes contigous rows
2906
- if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
3211
+ node_idx++;
3212
+
3213
+ // ARGSORT-> VIEW
3214
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3215
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
2907
3216
  return false;
2908
3217
  }
3218
+ node_idx++;
2909
3219
 
2910
- if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
3220
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
2911
3221
  return false;
2912
3222
  }
2913
3223
 
2914
- return true;
2915
- }
3224
+ // GET_ROWS
3225
+ if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
3226
+ return false;
3227
+ }
3228
+ node_idx++;
3229
+ } else if (args.delayed_softmax) {
3230
+ if (node_idx - 2 < 0) {
3231
+ return false;
3232
+ }
3233
+ ggml_tensor * probs_reshaped = nodes[node_idx - 2];
3234
+
3235
+ // VIEW->ARGSORT
3236
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3237
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3238
+ return false;
3239
+ }
3240
+ node_idx++;
3241
+
3242
+ // GET_ROWS
3243
+ if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3244
+ nodes[node_idx]->src[0] != probs_reshaped) {
3245
+ return false;
3246
+ }
3247
+ node_idx++;
3248
+
3249
+ static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
3250
+
3251
+ for (const ggml_op op : remaining_ops) {
3252
+ if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3253
+ return false;
3254
+ }
3255
+ node_idx++;
3256
+ }
3257
+ }
3258
+
3259
+ // At this point we can check for norm + scale. Everything is now at least valid till the norm
3260
+ if (node_idx >= n_nodes) {
3261
+ return true;
3262
+ }
3263
+
3264
+ if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
3265
+ //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
3266
+ static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
3267
+
3268
+ args.norm = true;
3269
+ for (const ggml_op op : norm_ops) {
3270
+ if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3271
+ node_idx++;
3272
+ } else {
3273
+ args.norm = false;
3274
+ return true;
3275
+ }
3276
+ }
3277
+
3278
+ // DIV <- CLAMP, RESHAPE
3279
+ if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3280
+ nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
3281
+ args.norm = false;
3282
+ return true;
3283
+ }
3284
+ node_idx++;
3285
+
3286
+ if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3287
+ args.norm = false;
3288
+ return true;
3289
+ }
3290
+
3291
+ node_idx++;
3292
+ }
3293
+
3294
+ if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3295
+ args.scale = true;
3296
+ }
3297
+
3298
+ return true;
3299
+ }
3300
+
3301
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3302
+ int node_idx,
3303
+ std::initializer_list<enum ggml_op> ops,
3304
+ std::initializer_list<enum ggml_unary_op> unary_ops) {
3305
+ #ifndef NDEBUG
3306
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3307
+ GGML_ASSERT(unary_ops.size() == num_unary);
3308
+ #endif
3309
+
3310
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3311
+ const std::initializer_list<enum ggml_op> & list2) {
3312
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3313
+ };
3314
+
3315
+ std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
3316
+ std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
3317
+
3318
+ std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
3319
+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
3320
+
3321
+ if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
3322
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
3323
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3324
+ const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
3325
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
3326
+ const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
3327
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
3328
+
3329
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3330
+ return true;
3331
+ }
3332
+ }
3333
+
3334
+ if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
3335
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3336
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3337
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
3338
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
3339
+
3340
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3341
+ return true;
3342
+ }
3343
+ }
3344
+
3345
+ std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
3346
+
3347
+ if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3348
+ const ggml_tensor * rope = cgraph->nodes[node_idx];
3349
+ const ggml_tensor * view = cgraph->nodes[node_idx + 1];
3350
+ const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
3351
+
3352
+ if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
3353
+ return true;
3354
+ }
3355
+ }
3356
+
3357
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
3358
+ return false;
3359
+ }
3360
+
3361
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
3362
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
3363
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
3364
+ const ggml_tensor *add = nullptr;
3365
+
3366
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
3367
+ add = cgraph->nodes[node_idx+2];
3368
+ }
3369
+
3370
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
3371
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
3372
+
3373
+ //rms norm only supports F32
3374
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
3375
+ mul->src[1]->type != GGML_TYPE_F32 ||
3376
+ mul->type != GGML_TYPE_F32) {
3377
+ return false;
3378
+ }
3379
+
3380
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
3381
+ add->src[1]->type != GGML_TYPE_F32 ||
3382
+ add->type != GGML_TYPE_F32) ) {
3383
+ return false;
3384
+ }
3385
+
3386
+ //if rms norm is the B operand, then we don't handle broadcast
3387
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
3388
+ return false;
3389
+ }
3390
+
3391
+ //rms_norm kernel assumes contiguous rows
3392
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
3393
+ return false;
3394
+ }
3395
+
3396
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
3397
+ return false;
3398
+ }
3399
+
3400
+ return true;
3401
+ }
3402
+
3403
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
3404
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3405
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3406
+ const ggml_tensor * silu = cgraph->nodes[node_idx+1];
3407
+
3408
+ if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3409
+ return false;
3410
+ }
3411
+
3412
+ return true;
3413
+ }
3414
+
3415
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
3416
+ && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
3417
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
3418
+ const ggml_tensor * mul = cgraph->nodes[node_idx+1];
3419
+
3420
+ if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
3421
+ return false;
3422
+ }
3423
+
3424
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
3425
+ return false;
3426
+ }
3427
+
3428
+ if (unary->type != mul->type) {
3429
+ return false;
3430
+ }
3431
+
3432
+ const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
3433
+ if (other->type != unary->type) {
3434
+ return false;
3435
+ }
3436
+ if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
3437
+ return false;
3438
+ }
3439
+
3440
+ return true;
3441
+ }
2916
3442
 
2917
3443
  if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
2918
3444
  && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
@@ -2938,39 +3464,297 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2938
3464
  return false;
2939
3465
  }
2940
3466
 
2941
- static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2942
- bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
3467
+ // returns whether the write (out) nodes overwrite the read nodes in operation
3468
+ static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
3469
+ int node_idx,
3470
+ int node_count,
3471
+ int * out_nodes,
3472
+ int out_count) {
3473
+ auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3474
+ const int64_t a_start = (int64_t) a->data;
3475
+ const int64_t a_end = a_start + ggml_nbytes(a);
3476
+
3477
+ const int64_t b_start = (int64_t) b->data;
3478
+ const int64_t b_end = b_start + ggml_nbytes(b);
3479
+
3480
+ if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3481
+ return true;
3482
+ }
3483
+
3484
+ return false;
3485
+ };
3486
+
3487
+ bool is_ok = true;
3488
+ // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
3489
+ if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
3490
+ return true;
3491
+ }
3492
+
3493
+ for (int i = 0; i < out_count; ++i) {
3494
+ const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
3495
+
3496
+ for (int j = node_idx; j < node_idx + node_count; ++j) {
3497
+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
3498
+ // the destination and the src is not an intermediate node that's being
3499
+ // elided, then disable fusion.
3500
+
3501
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3502
+ const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
3503
+
3504
+ if (!src || src->op == GGML_OP_NONE) {
3505
+ continue;
3506
+ }
3507
+
3508
+ if (nodes_overlap(dst, src)) {
3509
+ bool found = false;
3510
+
3511
+ for (int k = node_idx; k < j; ++k) {
3512
+ if (cgraph->nodes[k] == src) {
3513
+ found = true;
3514
+ break;
3515
+ }
3516
+ }
3517
+
3518
+ if (!found) {
3519
+ is_ok = false;
3520
+ break;
3521
+ }
3522
+ }
3523
+ }
3524
+ }
3525
+ }
3526
+
3527
+ return is_ok;
3528
+ }
3529
+
3530
+ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
3531
+ bool graph_evaluated_or_captured = false;
3532
+
2943
3533
  // flag used to determine whether it is an integrated_gpu
2944
- const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3534
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3535
+
3536
+ ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
3537
+ bool is_concurrent_event_active = false;
3538
+ ggml_cuda_concurrent_event * concurrent_event = nullptr;
3539
+ bool should_launch_concurrent_events = false;
3540
+
3541
+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3542
+ if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3543
+ concurrent_event = &stream_ctx.concurrent_events[node];
3544
+
3545
+ is_concurrent_event_active = true;
3546
+
3547
+ GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3548
+
3549
+ cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3550
+ GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3551
+ CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3552
+
3553
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3554
+ cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3555
+ CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3556
+ }
3557
+ }
3558
+ };
2945
3559
 
2946
3560
  while (!graph_evaluated_or_captured) {
2947
3561
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2948
3562
  // With the use of CUDA graphs, the execution will be performed by the graph launch.
2949
3563
  if (!use_cuda_graph || cuda_graph_update_required) {
3564
+ [[maybe_unused]] int prev_i = 0;
3565
+
3566
+ if (stream_ctx.concurrent_events.size() > 0) {
3567
+ should_launch_concurrent_events = true;
3568
+ for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
3569
+ should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
3570
+ }
3571
+ }
3572
+
3573
+ if (should_launch_concurrent_events) {
3574
+ // Restore original node order within each concurrent region to enable fusion within streams
3575
+
3576
+ std::unordered_map<const ggml_tensor *, int> node_to_idx;
3577
+ node_to_idx.reserve(cgraph->n_nodes);
3578
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
3579
+ node_to_idx[cgraph->nodes[i]] = i;
3580
+ }
3581
+
3582
+ for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
3583
+ // Find positions of all nodes from this event in the current graph
3584
+ std::vector<int> positions;
3585
+ positions.reserve(event.original_order.size());
3586
+
3587
+ bool all_found = true;
3588
+ for (const ggml_tensor * orig_node : event.original_order) {
3589
+ auto it = node_to_idx.find(orig_node);
3590
+ if (it != node_to_idx.end()) {
3591
+ positions.push_back(it->second);
3592
+ } else {
3593
+ all_found = false;
3594
+ break;
3595
+ }
3596
+ }
3597
+
3598
+ if (!all_found || positions.size() != event.original_order.size()) {
3599
+ continue;
3600
+ }
3601
+
3602
+ // Sort positions to get contiguous range
3603
+ std::vector<int> sorted_positions = positions;
3604
+ std::sort(sorted_positions.begin(), sorted_positions.end());
3605
+
3606
+ bool is_contiguous = true;
3607
+ for (size_t i = 1; i < sorted_positions.size(); ++i) {
3608
+ if (sorted_positions[i] != sorted_positions[i-1] + 1) {
3609
+ is_contiguous = false;
3610
+ break;
3611
+ }
3612
+ }
3613
+
3614
+ if (!is_contiguous) {
3615
+ continue;
3616
+ }
3617
+
3618
+ // Restore original order at the sorted positions
3619
+ int start_pos = sorted_positions[0];
3620
+ for (size_t i = 0; i < event.original_order.size(); ++i) {
3621
+ cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
3622
+ }
3623
+ }
3624
+ } else {
3625
+ stream_ctx.concurrent_events.clear();
3626
+ }
2950
3627
 
2951
3628
  for (int i = 0; i < cgraph->n_nodes; i++) {
2952
3629
  ggml_tensor * node = cgraph->nodes[i];
3630
+ if (is_concurrent_event_active) {
3631
+ GGML_ASSERT(concurrent_event);
3632
+
3633
+ if (node == concurrent_event->join_node) {
3634
+ cuda_ctx->curr_stream_no = 0;
3635
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3636
+ // Wait on join events of forked streams in the main stream
3637
+ CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
3638
+ cuda_ctx->stream(cuda_ctx->device, i)));
3639
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
3640
+ }
3641
+
3642
+ is_concurrent_event_active = false;
3643
+ concurrent_event = nullptr;
3644
+ } else {
3645
+ GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
3646
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3647
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3648
+ }
3649
+ } else if (i - prev_i > 1) {
3650
+ //the previous node was fused
3651
+ const ggml_tensor * prev_node = cgraph->nodes[i - 1];
3652
+ try_launch_concurrent_event(prev_node);
3653
+
3654
+ if (is_concurrent_event_active) {
3655
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3656
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3657
+ }
3658
+ }
3659
+
3660
+ #ifdef GGML_CUDA_DEBUG
3661
+ const int nodes_fused = i - prev_i - 1;
3662
+ if (nodes_fused > 0) {
3663
+ GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
3664
+ }
3665
+ #endif
3666
+ prev_i = i;
2953
3667
 
2954
3668
  if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2955
3669
  continue;
2956
3670
  }
2957
3671
 
3672
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
3673
+ continue;
3674
+ }
3675
+
3676
+ // start of fusion operations
2958
3677
  static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
2959
3678
  if (!disable_fusion) {
2960
-
2961
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
2962
- ggml_tensor * weights = cgraph->nodes[i+8];
2963
- ggml_tensor * selected_experts = cgraph->nodes[i+3];
2964
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
2965
- i += 8;
2966
- continue;
3679
+ ggml_cuda_topk_moe_args args;
3680
+
3681
+ if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3682
+ cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3683
+ const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3684
+
3685
+ std::vector<ggml_op> ops;
3686
+
3687
+ if (can_fuse) {
3688
+ const ggml_tensor * logits = node->src[0];
3689
+ ggml_tensor * weights = nullptr;
3690
+ ggml_tensor * ids = nullptr;
3691
+ const ggml_tensor * bias = nullptr;
3692
+ const ggml_tensor * clamp = nullptr;
3693
+ const ggml_tensor * scale = nullptr;
3694
+
3695
+ if (!args.delayed_softmax) {
3696
+ ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3697
+ int out_nodes[2]; // nodes which can't be elided
3698
+
3699
+ if (args.prob_bias) {
3700
+ bias = cgraph->nodes[i + 2]->src[1];
3701
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
3702
+ GGML_OP_VIEW, GGML_OP_GET_ROWS });
3703
+ out_nodes[0] = i + 4;
3704
+ ids = cgraph->nodes[i + 4];
3705
+ } else {
3706
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
3707
+ GGML_OP_GET_ROWS });
3708
+ out_nodes[0] = i + 3;
3709
+ ids = cgraph->nodes[i + 3];
3710
+ }
3711
+
3712
+ if (args.norm) {
3713
+ ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
3714
+ GGML_OP_DIV, GGML_OP_RESHAPE });
3715
+ clamp = cgraph->nodes[i + ops.size() - 3];
3716
+ }
3717
+ if (args.scale) {
3718
+ ops.insert(ops.end(), { GGML_OP_SCALE });
3719
+ scale = cgraph->nodes[i + ops.size() - 1];
3720
+ }
3721
+
3722
+ weights = cgraph->nodes[i + ops.size() - 1];
3723
+ out_nodes[1] = i + ops.size() - 1;
3724
+
3725
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3726
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3727
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3728
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3729
+ i += ops.size() - 1;
3730
+ continue;
3731
+ }
3732
+ } else if (!args.norm && !args.prob_bias) {
3733
+ //special case gpt-oss, no norm, no bias.
3734
+ ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
3735
+ GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3736
+ weights = cgraph->nodes[i + 5];
3737
+ ids = cgraph->nodes[i + 1];
3738
+ const ggml_tensor * softmax = cgraph->nodes[i + 4];
3739
+
3740
+ int out_nodes[2] = { i + 1, i + 5 };
3741
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3742
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3743
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3744
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3745
+ i += ops.size() - 1;
3746
+ continue;
3747
+ }
3748
+ }
3749
+ }
2967
3750
  }
2968
3751
 
2969
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
2970
- ggml_tensor * weights = cgraph->nodes[i+4];
2971
- ggml_tensor * selected_experts = cgraph->nodes[i+3];
2972
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
2973
- i += 4;
3752
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3753
+ ggml_tensor * rope = cgraph->nodes[i];
3754
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
3755
+
3756
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3757
+ i += 2;
2974
3758
  continue;
2975
3759
  }
2976
3760
 
@@ -2994,17 +3778,208 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2994
3778
  n_fuse++;
2995
3779
 
2996
3780
  if (n_fuse > 1) {
3781
+ ggml_tensor fused_add_node;
3782
+ memcpy(&fused_add_node, node, sizeof(ggml_tensor));
2997
3783
  for (int j = 0; j < n_fuse - 1; ++j) {
2998
- node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3784
+ fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
2999
3785
  }
3000
- cgraph->nodes[i + n_fuse - 1]->data = node->data;
3001
- ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
3786
+ fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
3787
+ ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
3002
3788
  i += n_fuse - 1;
3003
3789
 
3004
3790
  continue;
3005
3791
  }
3006
3792
  }
3007
3793
 
3794
+ bool fused_mul_mat_vec = false;
3795
+ int fused_node_count = 0;
3796
+
3797
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3798
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3799
+
3800
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3801
+ ggml_tensor * glu = cgraph->nodes[i + 4];
3802
+ ggml_tensor * gate_bias_n = glu->src[0];
3803
+ ggml_tensor * up_bias_n = glu->src[1];
3804
+
3805
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3806
+ ggml_tensor * gate_n = nullptr;
3807
+ ggml_tensor * up_n = nullptr;
3808
+
3809
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3810
+ gate_n = cgraph->nodes[i];
3811
+ up_n = cgraph->nodes[i + 2];
3812
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3813
+ gate_n = cgraph->nodes[i + 2];
3814
+ up_n = cgraph->nodes[i];
3815
+ } else {
3816
+ continue;
3817
+ }
3818
+
3819
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3820
+ if (op_bias == GGML_OP_ADD) {
3821
+ if (bias_node->src[0] == mul_node) {
3822
+ return bias_node->src[1];
3823
+ }
3824
+ if (bias_node->src[1] == mul_node) {
3825
+ return bias_node->src[0];
3826
+ }
3827
+ return (ggml_tensor *) nullptr;
3828
+ }
3829
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3830
+ GGML_ASSERT(bias_node->src[0] == mul_node);
3831
+ return bias_node->src[1];
3832
+ };
3833
+
3834
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
3835
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3836
+
3837
+ if (!up_bias_tensor || !gate_bias_tensor) {
3838
+ continue;
3839
+ }
3840
+
3841
+ // we don't support repeating adds
3842
+ if (bias_op == GGML_OP_ADD &&
3843
+ (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3844
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3845
+ continue;
3846
+ }
3847
+
3848
+ const ggml_tensor * src0 = up_n->src[0];
3849
+ const ggml_tensor * src1 = up_n->src[1];
3850
+ const ggml_tensor * ids = up_n->src[2];
3851
+
3852
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3853
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3854
+ fusion_data.gate = gate_n->src[0];
3855
+ fusion_data.x_bias = up_bias_tensor;
3856
+ fusion_data.gate_bias = gate_bias_tensor;
3857
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3858
+
3859
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3860
+ fused_mul_mat_vec = true;
3861
+ fused_node_count = 5;
3862
+ break;
3863
+ }
3864
+
3865
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3866
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3867
+ fusion_data.gate = gate_n->src[0];
3868
+ fusion_data.x_bias = up_bias_tensor;
3869
+ fusion_data.gate_bias = gate_bias_tensor;
3870
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3871
+
3872
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3873
+ fused_mul_mat_vec = true;
3874
+ fused_node_count = 5;
3875
+ break;
3876
+ }
3877
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3878
+ ggml_tensor * glu = cgraph->nodes[i + 2];
3879
+ ggml_tensor * gate = glu->src[0];
3880
+ ggml_tensor * up = glu->src[1];
3881
+
3882
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3883
+ || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3884
+
3885
+ if (!ok) continue;
3886
+
3887
+ const ggml_tensor * src0 = up->src[0];
3888
+ const ggml_tensor * src1 = up->src[1];
3889
+ const ggml_tensor * ids = up->src[2];
3890
+
3891
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3892
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3893
+ fusion_data.gate = gate->src[0];
3894
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3895
+
3896
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3897
+ fused_mul_mat_vec = true;
3898
+ fused_node_count = 3;
3899
+ break;
3900
+ }
3901
+
3902
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3903
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3904
+ fusion_data.gate = gate->src[0];
3905
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3906
+
3907
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3908
+ fused_mul_mat_vec = true;
3909
+ fused_node_count = 3;
3910
+ break;
3911
+ }
3912
+ }
3913
+ }
3914
+
3915
+ if (fused_mul_mat_vec) {
3916
+ i += fused_node_count - 1;
3917
+ continue;
3918
+ }
3919
+
3920
+ fused_mul_mat_vec = false;
3921
+ fused_node_count = 0;
3922
+
3923
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3924
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3925
+
3926
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3927
+ continue;
3928
+ }
3929
+
3930
+ ggml_tensor * mm_node = cgraph->nodes[i];
3931
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
3932
+
3933
+ ggml_tensor * bias_tensor = nullptr;
3934
+ if (bias_op == GGML_OP_ADD) {
3935
+ if (bias_node->src[0] == mm_node) {
3936
+ bias_tensor = bias_node->src[1];
3937
+ } else if (bias_node->src[1] == mm_node) {
3938
+ bias_tensor = bias_node->src[0];
3939
+ } else {
3940
+ continue;
3941
+ }
3942
+ } else {
3943
+ if (bias_node->src[0] != mm_node) {
3944
+ continue;
3945
+ }
3946
+ bias_tensor = bias_node->src[1];
3947
+ }
3948
+
3949
+ const ggml_tensor * src0 = mm_node->src[0];
3950
+ const ggml_tensor * src1 = mm_node->src[1];
3951
+ const ggml_tensor * ids = mm_node->src[2];
3952
+
3953
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3954
+ continue;
3955
+ }
3956
+
3957
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3958
+ continue;
3959
+ }
3960
+
3961
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3962
+ fusion_data.x_bias = bias_tensor;
3963
+
3964
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3965
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3966
+ fused_mul_mat_vec = true;
3967
+ fused_node_count = 2;
3968
+ break;
3969
+ }
3970
+
3971
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3972
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3973
+ fused_mul_mat_vec = true;
3974
+ fused_node_count = 2;
3975
+ break;
3976
+ }
3977
+ }
3978
+
3979
+ if (fused_mul_mat_vec) {
3980
+ i += fused_node_count - 1;
3981
+ continue;
3982
+ }
3008
3983
 
3009
3984
  if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3010
3985
  ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
@@ -3018,6 +3993,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3018
3993
  continue;
3019
3994
  }
3020
3995
 
3996
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
3997
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
3998
+ i++;
3999
+ continue;
4000
+ }
4001
+
4002
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
4003
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
4004
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
4005
+ ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
4006
+ i++;
4007
+ continue;
4008
+ }
4009
+
3021
4010
  if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
3022
4011
  i += 2;
3023
4012
  ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
@@ -3035,24 +4024,29 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3035
4024
  }
3036
4025
  #else
3037
4026
  GGML_UNUSED(integrated);
3038
- #endif // NDEBUG
4027
+ #endif // NDEBUG
3039
4028
 
3040
4029
  bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
3041
4030
  if (!ok) {
3042
4031
  GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
3043
4032
  }
3044
4033
  GGML_ASSERT(ok);
4034
+
4035
+ if (!is_concurrent_event_active) {
4036
+ try_launch_concurrent_event(node);
4037
+ }
3045
4038
  }
3046
4039
  }
3047
4040
 
3048
4041
  #ifdef USE_CUDA_GRAPH
4042
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3049
4043
  if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
3050
- if (cuda_ctx->cuda_graph->graph != nullptr) {
3051
- CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
3052
- cuda_ctx->cuda_graph->graph = nullptr;
4044
+ if (graph->graph != nullptr) {
4045
+ CUDA_CHECK(cudaGraphDestroy(graph->graph));
4046
+ graph->graph = nullptr;
3053
4047
  }
3054
4048
 
3055
- CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
4049
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
3056
4050
  graph_evaluated_or_captured = true; // CUDA graph has been captured
3057
4051
 
3058
4052
  std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3065,74 +4059,82 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
3065
4059
  }
3066
4060
 
3067
4061
  if (use_cuda_graph) {
3068
- if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
3069
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
4062
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4063
+ if (graph->instance == nullptr) { // Create executable graph from captured graph.
4064
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3070
4065
  }
3071
4066
  if (cuda_graph_update_required) { // Update graph executable
3072
- update_cuda_graph_executable(cuda_ctx);
4067
+ ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
3073
4068
  }
3074
4069
  // Launch graph
3075
- CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
4070
+ CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
3076
4071
  #else
4072
+ GGML_UNUSED(graph_key);
3077
4073
  graph_evaluated_or_captured = true;
3078
4074
  #endif // USE_CUDA_GRAPH
3079
4075
  }
3080
4076
  }
3081
4077
 
3082
- static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3083
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3084
-
3085
- ggml_cuda_set_device(cuda_ctx->device);
3086
-
3087
4078
  #ifdef USE_CUDA_GRAPH
3088
- static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
4079
+ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
4080
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3089
4081
 
3090
- // Objects required for CUDA Graph
3091
- if (cuda_ctx->cuda_graph == nullptr) {
3092
- cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
3093
- }
3094
-
3095
- bool use_cuda_graph = true;
3096
- bool cuda_graph_update_required = false;
3097
-
3098
- if (cuda_ctx->cuda_graph->graph == nullptr) {
4082
+ if (graph->graph == nullptr) {
3099
4083
  if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3100
- cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3101
- #ifndef NDEBUG
3102
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
3103
- #endif
4084
+ if (!graph->disable_due_to_gpu_arch) {
4085
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
4086
+ }
4087
+ graph->disable_due_to_gpu_arch = true;
3104
4088
  }
3105
4089
  }
3106
4090
 
3107
- // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
3108
- // or previous graph capture failure.
3109
- // Also disable for multi-gpu for now. TO DO investigate
3110
- if (disable_cuda_graphs_due_to_env
3111
- || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
3112
- || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
3113
- || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
3114
- use_cuda_graph = false;
3115
- }
4091
+ return graph->is_enabled();
4092
+ }
4093
+ #endif // USE_CUDA_GRAPH
3116
4094
 
3117
- if (use_cuda_graph) {
3118
- cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
4095
+ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4096
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3119
4097
 
3120
- use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
4098
+ ggml_cuda_set_device(cuda_ctx->device);
3121
4099
 
3122
- // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
3123
- if (use_cuda_graph && cuda_graph_update_required) {
3124
- cuda_ctx->cuda_graph->number_consecutive_updates++;
3125
- } else {
3126
- cuda_ctx->cuda_graph->number_consecutive_updates = 0;
3127
- }
4100
+ bool use_cuda_graph = false;
4101
+ bool cuda_graph_update_required = false;
4102
+ const void * graph_key = nullptr;
3128
4103
 
3129
- if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
3130
- cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
3131
- #ifndef NDEBUG
3132
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
3133
- #endif
4104
+ #ifdef USE_CUDA_GRAPH
4105
+ graph_key = ggml_cuda_graph_get_key(cgraph);
4106
+
4107
+ ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4108
+
4109
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4110
+ if (graph->is_enabled()) {
4111
+ const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
4112
+ if (graph_compatible) {
4113
+ const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
4114
+
4115
+ if (!graph->warmup_complete) {
4116
+ // Warmup: need at least 2 calls with no property change on the 2nd call
4117
+ if (!properties_changed) {
4118
+ graph->warmup_complete = true;
4119
+ GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
4120
+ use_cuda_graph = true;
4121
+ cuda_graph_update_required = true;
4122
+ }
4123
+ // else: properties changed or first call - execute directly (use_cuda_graph stays false)
4124
+ } else {
4125
+ // Post-warmup: normal CUDA graph operation
4126
+ if (properties_changed) {
4127
+ // Properties changed - reset warmup, execute directly until stable again
4128
+ graph->warmup_complete = false;
4129
+ GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
4130
+ } else {
4131
+ use_cuda_graph = true;
4132
+ cuda_graph_update_required = graph->instance == nullptr;
4133
+ }
4134
+ }
3134
4135
  }
3135
4136
  }
4137
+ #endif // USE_CUDA_GRAPH
3136
4138
 
3137
4139
  if (use_cuda_graph && cuda_graph_update_required) {
3138
4140
  // Start CUDA graph capture
@@ -3144,18 +4146,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3144
4146
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3145
4147
  }
3146
4148
 
3147
- if (!use_cuda_graph) {
3148
- cuda_ctx->cuda_graph->use_cpy_indirection = false;
3149
- }
3150
-
3151
- #else
3152
- bool use_cuda_graph = false;
3153
- bool cuda_graph_update_required = false;
3154
- #endif // USE_CUDA_GRAPH
3155
-
3156
- bool graph_evaluated_or_captured = false;
3157
-
3158
- evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
4149
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
3159
4150
 
3160
4151
  return GGML_STATUS_SUCCESS;
3161
4152
  }
@@ -3185,6 +4176,250 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
3185
4176
  }
3186
4177
  }
3187
4178
 
4179
+ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
4180
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
4181
+
4182
+ #ifdef USE_CUDA_GRAPH
4183
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
4184
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4185
+ #else
4186
+ const bool use_cuda_graph = false;
4187
+ GGML_UNUSED(cuda_ctx);
4188
+ GGML_UNUSED(cgraph);
4189
+ #endif
4190
+
4191
+ static bool enable_graph_optimization = [] {
4192
+ const char * env = getenv("GGML_CUDA_GRAPH_OPT");
4193
+ return env != nullptr && atoi(env) == 1;
4194
+ }();
4195
+
4196
+ if (!enable_graph_optimization) {
4197
+ return;
4198
+ }
4199
+
4200
+ ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
4201
+ stream_context.reset();
4202
+
4203
+ if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
4204
+ return;
4205
+ }
4206
+
4207
+ // number of out-degrees for a particular node
4208
+ std::unordered_map<const ggml_tensor *, int> fan_out;
4209
+ // reverse mapping of node to index in the cgraph
4210
+ std::unordered_map<const ggml_tensor *, int> node_indices;
4211
+
4212
+ const auto & is_noop = [](const ggml_tensor * node) -> bool {
4213
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
4214
+ node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
4215
+ };
4216
+
4217
+ const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
4218
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
4219
+ if (dst->src[s] == src) {
4220
+ return true;
4221
+ }
4222
+ }
4223
+ // implicit dependency if they view the same tensor
4224
+ const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
4225
+ const ggml_tensor * src2 = src->view_src ? src->view_src : src;
4226
+ if (dst2 == src2) {
4227
+ return true;
4228
+ }
4229
+ return false;
4230
+ };
4231
+
4232
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
4233
+ const ggml_tensor * node = cgraph->nodes[node_idx];
4234
+ node_indices[node] = node_idx;
4235
+
4236
+ if (is_noop(node)) {
4237
+ continue;
4238
+ }
4239
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
4240
+ const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
4241
+ //TODO: check why nrows > 1 fails
4242
+ if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
4243
+ fan_out[src] += 1;
4244
+ }
4245
+ }
4246
+ }
4247
+
4248
+ // Target Q, K, V for concurrency
4249
+ // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
4250
+ // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
4251
+ // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
4252
+ // 3. account for all branches from the fork to the join
4253
+ // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
4254
+ // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
4255
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
4256
+
4257
+ const int min_fan_out = 3;
4258
+ const int max_fan_out = 3;
4259
+
4260
+ // store {fork_idx, join_idx}
4261
+ std::vector<std::pair<int, int>> concurrent_node_ranges;
4262
+
4263
+ for (const auto & [root_node, count] : fan_out) {
4264
+ if (count >= min_fan_out && count <= max_fan_out) {
4265
+ const int root_node_idx = node_indices[root_node];
4266
+
4267
+ // only optimize for attn_norm
4268
+ // TODO: make this more generic
4269
+ if (!strstr(root_node->name, "attn_norm")) {
4270
+ continue;
4271
+ }
4272
+
4273
+ bool is_part_of_event = false;
4274
+ for (const auto & [start, end] : concurrent_node_ranges) {
4275
+ if (root_node_idx >= start && root_node_idx <= end) {
4276
+ is_part_of_event = true;
4277
+ }
4278
+ }
4279
+
4280
+ if (is_part_of_event) {
4281
+ continue;
4282
+ }
4283
+
4284
+ std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
4285
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
4286
+ const ggml_tensor * node = cgraph->nodes[i];
4287
+ if (!is_noop(node) && depends_on(node, root_node)) {
4288
+ nodes_per_branch.push_back({ node });
4289
+ }
4290
+ }
4291
+
4292
+ GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
4293
+
4294
+ //find the join point
4295
+ const ggml_tensor * join_node = nullptr;
4296
+
4297
+ const auto & belongs_to_branch = [&](const ggml_tensor * node,
4298
+ const std::vector<const ggml_tensor *> & branch) -> bool {
4299
+ for (const ggml_tensor * n : branch) {
4300
+ if (depends_on(node, n)) {
4301
+ return true;
4302
+ }
4303
+ }
4304
+ return false;
4305
+ };
4306
+
4307
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
4308
+ const ggml_tensor * curr_node = cgraph->nodes[i];
4309
+
4310
+ int num_joins = 0;
4311
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4312
+ if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
4313
+ num_joins++;
4314
+ }
4315
+ }
4316
+
4317
+ if (num_joins >= 2) {
4318
+ join_node = curr_node;
4319
+ break;
4320
+ }
4321
+
4322
+ bool found_branch = false;
4323
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4324
+ std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
4325
+ if (belongs_to_branch(curr_node, branch_vec)) {
4326
+ //continue accumulating
4327
+ if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
4328
+ branch_vec.push_back(curr_node);
4329
+ }
4330
+ found_branch = true;
4331
+ }
4332
+ }
4333
+
4334
+ if (!found_branch && is_noop(curr_node)) {
4335
+ // we can put it in any branch because it will be ignored
4336
+ nodes_per_branch[0].push_back({ curr_node });
4337
+ }
4338
+ }
4339
+
4340
+ if (join_node) {
4341
+ //Create ggml_cuda_concurrent_event
4342
+ ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
4343
+ concurrent_event.join_node = join_node;
4344
+
4345
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4346
+ for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
4347
+ concurrent_event.stream_mapping[n] = branch_idx + 1;
4348
+ }
4349
+ }
4350
+
4351
+ int fork_node_idx = node_indices[root_node];
4352
+ int join_node_idx = node_indices[join_node];
4353
+
4354
+ int current_branch_idx = 0;
4355
+ int current_node_idx = fork_node_idx + 1;
4356
+ const int n_branches = nodes_per_branch.size();
4357
+
4358
+ int total_branch_nodes = 0;
4359
+ for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
4360
+ total_branch_nodes += branch_nodes.size();
4361
+ }
4362
+
4363
+ // there are other nodes in the middle which are unaccounted for
4364
+ // usually (cpy) nodes, then ignore this fork
4365
+ if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
4366
+ GGML_LOG_DEBUG(
4367
+ "Skipping %s because the number of nodes in the middle is not equal to the total number of "
4368
+ "branch nodes %d != %d\n",
4369
+ root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
4370
+ continue;
4371
+ }
4372
+
4373
+ // Save the original order of nodes in this region before interleaving
4374
+ // This is used later to restore grouping for fusion within streams
4375
+ concurrent_event.original_order.reserve(total_branch_nodes);
4376
+ for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
4377
+ concurrent_event.original_order.push_back(cgraph->nodes[i]);
4378
+ }
4379
+
4380
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
4381
+ GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
4382
+ concurrent_events.emplace(root_node, std::move(concurrent_event));
4383
+ GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
4384
+ concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
4385
+
4386
+ // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
4387
+ // example transformation:
4388
+ // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
4389
+ // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
4390
+ while (current_node_idx < join_node_idx) {
4391
+ std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
4392
+
4393
+ bool has_node = false;
4394
+ for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
4395
+ has_node |= branch_node.size() > 0;
4396
+ }
4397
+
4398
+ GGML_ASSERT(has_node);
4399
+
4400
+ if (branch_nodes.empty()) {
4401
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4402
+ continue;
4403
+ }
4404
+
4405
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4406
+ current_node_idx++;
4407
+ branch_nodes.erase(branch_nodes.begin());
4408
+
4409
+ // append all empty nodes
4410
+ while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
4411
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4412
+ current_node_idx++;
4413
+ branch_nodes.erase(branch_nodes.begin());
4414
+ }
4415
+
4416
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4417
+ }
4418
+ }
4419
+ }
4420
+ }
4421
+ }
4422
+
3188
4423
  static const ggml_backend_i ggml_backend_cuda_interface = {
3189
4424
  /* .get_name = */ ggml_backend_cuda_get_name,
3190
4425
  /* .free = */ ggml_backend_cuda_free,
@@ -3199,7 +4434,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
3199
4434
  /* .graph_compute = */ ggml_backend_cuda_graph_compute,
3200
4435
  /* .event_record = */ ggml_backend_cuda_event_record,
3201
4436
  /* .event_wait = */ ggml_backend_cuda_event_wait,
3202
- /* .graph_optimize = */ NULL,
4437
+ /* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
3203
4438
  };
3204
4439
 
3205
4440
  static ggml_guid_t ggml_backend_cuda_guid() {
@@ -3270,6 +4505,7 @@ struct ggml_backend_cuda_device_context {
3270
4505
  std::string name;
3271
4506
  std::string description;
3272
4507
  std::string pci_bus_id;
4508
+ int op_offload_min_batch_size;
3273
4509
  };
3274
4510
 
3275
4511
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3282,10 +4518,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
3282
4518
  return ctx->description.c_str();
3283
4519
  }
3284
4520
 
4521
+ #if defined(__linux__)
4522
+ // Helper function to get available memory from /proc/meminfo for UMA systems
4523
+ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
4524
+ FILE * meminfo_file = nullptr;
4525
+ // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
4526
+ const size_t BUFFER_SIZE = 2048;
4527
+ auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
4528
+ size_t bytes_read = 0;
4529
+ long huge_tlb_total_pages = -1;
4530
+ long huge_tlb_free_pages = -1;
4531
+ long huge_tlb_page_size = -1;
4532
+
4533
+ if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
4534
+ return false;
4535
+ }
4536
+
4537
+ meminfo_file = fopen("/proc/meminfo", "r");
4538
+ if (meminfo_file == nullptr) {
4539
+ GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
4540
+ return false;
4541
+ }
4542
+
4543
+ // Read file into buffer
4544
+ bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
4545
+ fclose(meminfo_file);
4546
+
4547
+ if (bytes_read == 0) {
4548
+ GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
4549
+ return false;
4550
+ }
4551
+ file_buffer[bytes_read] = '\0';
4552
+
4553
+ *available_memory_kb = -1;
4554
+ *free_swap_kb = -1;
4555
+
4556
+ // Parse the file buffer line by line
4557
+ char * line = file_buffer.get();
4558
+ char * line_next;
4559
+ while (line < file_buffer.get() + bytes_read) {
4560
+ // Find the end of the current line
4561
+ line_next = strchr(line, '\n');
4562
+ if (line_next != nullptr) {
4563
+ *line_next = '\0';
4564
+ line_next++;
4565
+ } else {
4566
+ line_next = file_buffer.get() + bytes_read;
4567
+ }
4568
+
4569
+ long value;
4570
+ if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
4571
+ *available_memory_kb = value;
4572
+ } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
4573
+ *free_swap_kb = value;
4574
+ } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
4575
+ huge_tlb_total_pages = value;
4576
+ } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
4577
+ huge_tlb_free_pages = value;
4578
+ } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
4579
+ huge_tlb_page_size = value;
4580
+ }
4581
+
4582
+ line = line_next;
4583
+ }
4584
+
4585
+ if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
4586
+ *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
4587
+
4588
+ // Hugetlbfs pages are not swappable.
4589
+ *free_swap_kb = 0;
4590
+ }
4591
+
4592
+ GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
4593
+ return true;
4594
+ }
4595
+ #endif // defined(__linux__)
4596
+
3285
4597
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
3286
4598
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
3287
4599
  ggml_cuda_set_device(ctx->device);
3288
4600
  CUDA_CHECK(cudaMemGetInfo(free, total));
4601
+
4602
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17368
4603
+ #if defined(__linux__)
4604
+ // Check if this is a UMA (Unified Memory Architecture) system
4605
+ cudaDeviceProp prop;
4606
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
4607
+
4608
+ // Check if UMA is explicitly enabled via environment variable
4609
+ bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
4610
+ bool is_uma = prop.integrated > 0 || uma_env;
4611
+
4612
+ if (is_uma) {
4613
+ // For UMA systems (like DGX Spark), use system memory info
4614
+ long available_memory_kb = 0;
4615
+ long free_swap_kb = 0;
4616
+
4617
+ if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
4618
+ *free = (size_t)available_memory_kb * 1024;
4619
+ } else {
4620
+ GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
4621
+ }
4622
+ }
4623
+ #endif // defined(__linux__)
4624
+
3289
4625
  }
3290
4626
 
3291
4627
  static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -3373,7 +4709,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3373
4709
  case GGML_UNARY_OP_GELU_QUICK:
3374
4710
  case GGML_UNARY_OP_TANH:
3375
4711
  case GGML_UNARY_OP_EXP:
4712
+ case GGML_UNARY_OP_EXPM1:
4713
+ case GGML_UNARY_OP_SOFTPLUS:
3376
4714
  case GGML_UNARY_OP_ELU:
4715
+ case GGML_UNARY_OP_XIELU:
4716
+ case GGML_UNARY_OP_FLOOR:
4717
+ case GGML_UNARY_OP_CEIL:
4718
+ case GGML_UNARY_OP_ROUND:
4719
+ case GGML_UNARY_OP_TRUNC:
4720
+ // TODO: should become:
4721
+ //return ggml_is_contiguous_rows(op->src[0]);
3377
4722
  return ggml_is_contiguous(op->src[0]);
3378
4723
  default:
3379
4724
  return false;
@@ -3488,6 +4833,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3488
4833
  op->src[0]->type == GGML_TYPE_F32 &&
3489
4834
  (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
3490
4835
  } break;
4836
+ case GGML_OP_SET:
4837
+ {
4838
+ const ggml_type t = op->type;
4839
+ return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
4840
+ t == op->src[0]->type &&
4841
+ t == op->src[1]->type;
4842
+ } break;
3491
4843
  case GGML_OP_CPY:
3492
4844
  {
3493
4845
  ggml_type src0_type = op->src[0]->type;
@@ -3536,6 +4888,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3536
4888
  if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
3537
4889
  return true;
3538
4890
  }
4891
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
4892
+ return true;
4893
+ }
3539
4894
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
3540
4895
  return true;
3541
4896
  }
@@ -3580,7 +4935,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3580
4935
  case GGML_OP_L2_NORM:
3581
4936
  return true;
3582
4937
  case GGML_OP_RMS_NORM_BACK:
3583
- return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
4938
+ return ggml_is_contiguous(op->src[0]);
3584
4939
  break;
3585
4940
  case GGML_OP_NONE:
3586
4941
  case GGML_OP_RESHAPE:
@@ -3642,17 +4997,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3642
4997
  case GGML_OP_CONV_2D_DW:
3643
4998
  case GGML_OP_CONV_TRANSPOSE_2D:
3644
4999
  case GGML_OP_POOL_2D:
3645
- case GGML_OP_SUM:
3646
- case GGML_OP_ACC:
3647
5000
  return true;
5001
+ case GGML_OP_ACC:
5002
+ // TODO: extend support like so:
5003
+ //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
5004
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
5005
+ case GGML_OP_SUM:
5006
+ return ggml_is_contiguous_rows(op->src[0]);
5007
+ case GGML_OP_TOP_K:
3648
5008
  case GGML_OP_ARGSORT:
3649
- // TODO: Support arbitrary column width
5009
+ #ifndef GGML_CUDA_USE_CUB
3650
5010
  return op->src[0]->ne[0] <= 1024;
5011
+ #else
5012
+ return true;
5013
+ #endif
3651
5014
  case GGML_OP_SUM_ROWS:
3652
5015
  case GGML_OP_MEAN:
3653
5016
  case GGML_OP_GROUP_NORM:
3654
- case GGML_OP_PAD:
3655
5017
  return ggml_is_contiguous(op->src[0]);
5018
+ case GGML_OP_PAD:
5019
+ return true;
3656
5020
  case GGML_OP_UPSCALE:
3657
5021
  case GGML_OP_PAD_REFLECT_1D:
3658
5022
  case GGML_OP_ARANGE:
@@ -3662,13 +5026,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3662
5026
  case GGML_OP_GATED_LINEAR_ATTN:
3663
5027
  case GGML_OP_RWKV_WKV7:
3664
5028
  return true;
5029
+ case GGML_OP_GATED_DELTA_NET:
5030
+ //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
5031
+ #ifdef GGML_USE_MUSA
5032
+ return false;
5033
+ #else
5034
+ return true;
5035
+ #endif // GGML_USE_MUSA
3665
5036
  case GGML_OP_FLASH_ATTN_EXT:
3666
5037
  return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
3667
5038
  case GGML_OP_CROSS_ENTROPY_LOSS:
3668
5039
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
3669
5040
  case GGML_OP_OPT_STEP_ADAMW:
3670
5041
  case GGML_OP_OPT_STEP_SGD:
5042
+ case GGML_OP_FILL:
5043
+ case GGML_OP_CUMSUM:
5044
+ case GGML_OP_TRI:
5045
+ case GGML_OP_DIAG:
5046
+ case GGML_OP_SOLVE_TRI:
3671
5047
  return true;
5048
+
3672
5049
  default:
3673
5050
  return false;
3674
5051
  }
@@ -3696,11 +5073,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
3696
5073
  }
3697
5074
 
3698
5075
  static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
3699
- const int min_batch_size = 32;
3700
-
3701
- return get_op_batch_size(op) >= min_batch_size;
5076
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
3702
5077
 
3703
- GGML_UNUSED(dev);
5078
+ return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
3704
5079
  }
3705
5080
 
3706
5081
  static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@@ -3811,6 +5186,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
3811
5186
  features.push_back({ "FA_ALL_QUANTS", "1" });
3812
5187
  #endif
3813
5188
 
5189
+ {
5190
+ const auto & info = ggml_cuda_info();
5191
+ for (int id = 0; id < info.device_count; ++id) {
5192
+ if (blackwell_mma_available(info.devices[id].cc)) {
5193
+ features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
5194
+ break;
5195
+ }
5196
+ }
5197
+ }
5198
+
3814
5199
  #undef _STRINGIFY
3815
5200
  #undef STRINGIFY
3816
5201
 
@@ -3858,13 +5243,13 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
3858
5243
  std::lock_guard<std::mutex> lock(mutex);
3859
5244
  if (!initialized) {
3860
5245
  ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
5246
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
3861
5247
 
3862
5248
  for (int i = 0; i < ggml_cuda_info().device_count; i++) {
3863
5249
  ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
3864
5250
  dev_ctx->device = i;
3865
5251
  dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
3866
5252
 
3867
- ggml_cuda_set_device(i);
3868
5253
  cudaDeviceProp prop;
3869
5254
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
3870
5255
  dev_ctx->description = prop.name;
@@ -3872,6 +5257,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
3872
5257
  char pci_bus_id[16] = {};
3873
5258
  snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
3874
5259
  dev_ctx->pci_bus_id = pci_bus_id;
5260
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
3875
5261
 
3876
5262
  ggml_backend_dev_t dev = new ggml_backend_device {
3877
5263
  /* .iface = */ ggml_backend_cuda_device_interface,