whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -53,6 +53,7 @@
53
53
  #include "ggml-cuda/upscale.cuh"
54
54
  #include "ggml-cuda/wkv.cuh"
55
55
  #include "ggml-cuda/gla.cuh"
56
+ #include "ggml-cuda/gated_delta_net.cuh"
56
57
  #include "ggml-cuda/set.cuh"
57
58
  #include "ggml-cuda/set-rows.cuh"
58
59
  #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -70,17 +71,18 @@
70
71
  #include <condition_variable>
71
72
  #include <cstddef>
72
73
  #include <cstdint>
73
- #include <float.h>
74
+ #include <cfloat>
74
75
  #include <initializer_list>
75
76
  #include <limits>
76
77
  #include <map>
77
78
  #include <memory>
78
79
  #include <mutex>
79
- #include <stdarg.h>
80
- #include <stdio.h>
81
- #include <stdlib.h>
80
+ #include <cstdarg>
81
+ #include <cstdio>
82
+ #include <cstdlib>
82
83
  #include <string>
83
84
  #include <vector>
85
+ #include <unordered_set>
84
86
 
85
87
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
86
88
 
@@ -122,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
122
124
  err = cudaMallocManaged(ptr, size);
123
125
  #if defined(GGML_USE_HIP)
124
126
  if (err == hipSuccess) {
125
- CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
127
+ // hipMemAdviseSetCoarseGrain is an optional performance hint;
128
+ // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
129
+ cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
130
+ (void)hipGetLastError(); // clear any error
126
131
  }
127
132
 
128
133
  // fall back to cudaMalloc if not supported (e.g. on Windows)
@@ -203,7 +208,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
203
208
  GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
204
209
 
205
210
  int64_t total_vram = 0;
206
- GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
211
+ for (int id = 0; id < info.device_count; ++id) {
212
+ cudaDeviceProp prop;
213
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
214
+ total_vram += prop.totalGlobalMem;
215
+ }
216
+ GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
217
+ __func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
218
+ total_vram = 0;
207
219
 
208
220
  std::vector<std::pair<int, std::string>> turing_devices_without_mma;
209
221
  for (int id = 0; id < info.device_count; ++id) {
@@ -241,6 +253,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
241
253
  #else
242
254
  info.devices[id].supports_cooperative_launch = false;
243
255
  #endif // !(GGML_USE_MUSA)
256
+
244
257
  #if defined(GGML_USE_HIP)
245
258
  info.devices[id].smpbo = prop.sharedMemPerBlock;
246
259
 
@@ -255,22 +268,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
255
268
  info.devices[id].cc += prop.minor * 0x10;
256
269
  }
257
270
  }
258
- GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
271
+ GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n",
259
272
  id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
260
- device_vmm ? "yes" : "no", prop.warpSize);
273
+ device_vmm ? "yes" : "no", prop.warpSize,
274
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
261
275
  #elif defined(GGML_USE_MUSA)
262
276
  // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
263
277
  info.devices[id].warp_size = 32;
264
278
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
265
279
  info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
266
280
  info.devices[id].cc += prop.minor * 0x10;
267
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
268
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
281
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
282
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
283
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
269
284
  #else
270
285
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
271
286
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
272
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
273
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
287
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
288
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
289
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
274
290
  std::string device_name(prop.name);
275
291
  if (device_name == "NVIDIA GeForce MX450") {
276
292
  turing_devices_without_mma.push_back({ id, device_name });
@@ -285,6 +301,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
285
301
  // TODO: Check for future drivers the default scheduling strategy and
286
302
  // remove this call again when cudaDeviceScheduleSpin is default.
287
303
  if (prop.major == 12 && prop.minor == 1) {
304
+ CUDA_CHECK(cudaSetDevice(id));
288
305
  CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
289
306
  }
290
307
 
@@ -1224,6 +1241,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
1224
1241
  }
1225
1242
  }
1226
1243
 
1244
+ struct cublas_force_compute_type {
1245
+ bool fp32 = false;
1246
+ bool fp16 = false;
1247
+ };
1248
+
1249
+ static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
1250
+ static const cublas_force_compute_type compute_type = [] {
1251
+ cublas_force_compute_type result;
1252
+
1253
+ const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
1254
+ const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
1255
+
1256
+ GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);
1257
+
1258
+ if (ggml_cuda_force_cublas_compute_32f_env) {
1259
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
1260
+ result.fp32 = true;
1261
+ } else if (ggml_cuda_force_cublas_compute_16f_env) {
1262
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
1263
+ result.fp16 = true;
1264
+ }
1265
+
1266
+ return result;
1267
+ }();
1268
+
1269
+ return compute_type;
1270
+ }
1271
+
1227
1272
  static void ggml_cuda_op_mul_mat_cublas(
1228
1273
  ggml_backend_cuda_context & ctx,
1229
1274
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -1306,7 +1351,13 @@ static void ggml_cuda_op_mul_mat_cublas(
1306
1351
 
1307
1352
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1308
1353
 
1309
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1354
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
1355
+
1356
+ if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
1357
+ || GGML_CUDA_CC_IS_RDNA4(cc)
1358
+ || cc == GGML_CUDA_CC_VOLTA
1359
+ || force_compute_type.fp32))
1360
+ {
1310
1361
  const float alpha = 1.0f;
1311
1362
  const float beta = 0.0f;
1312
1363
  CUBLAS_CHECK(
@@ -1905,10 +1956,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1905
1956
  cudaDataType_t cu_data_type_b = traits::data_type;
1906
1957
  const void * alpha = traits::get_alpha();
1907
1958
  const void * beta = traits::get_beta();
1908
- const float alpha_f32 = 1.0f;
1909
- const float beta_f32 = 0.0f;
1910
1959
 
1911
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
1960
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
1961
+
1962
+ int id = ggml_cuda_get_device();
1963
+ const int cc = ggml_cuda_info().devices[id].cc;
1964
+ static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
1965
+
1966
+ // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),
1967
+ // so checking necessity of forced fp32 only for fp16 src0_type
1968
+ static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
1969
+
1970
+ const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
1971
+ || GGML_CUDA_CC_IS_RDNA4(cc)
1972
+ || cc == GGML_CUDA_CC_VOLTA
1973
+ || force_compute_type.fp32);
1974
+
1975
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
1912
1976
  if constexpr (src0_type == GGML_TYPE_F32) {
1913
1977
  dst_t = (char *) dst_ddf; // Direct F32 output
1914
1978
  } else {
@@ -1918,18 +1982,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1918
1982
  }
1919
1983
  } else {
1920
1984
  dst_t = (char *) dst_ddf;
1921
- cu_compute_type = CUBLAS_COMPUTE_32F;
1922
- cu_data_type = CUDA_R_32F;
1923
- alpha = &alpha_f32;
1924
- beta = &beta_f32;
1925
- }
1926
-
1927
- int id = ggml_cuda_get_device();
1928
- const int cc = ggml_cuda_info().devices[id].cc;
1929
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1930
- cu_compute_type = CUBLAS_COMPUTE_32F;
1931
- alpha = &alpha_f32;
1932
- beta = &beta_f32;
1985
+ cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
1986
+ cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
1987
+ alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
1988
+ beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
1933
1989
  }
1934
1990
 
1935
1991
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -2277,14 +2333,21 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2277
2333
 
2278
2334
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2279
2335
 
2336
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2280
2337
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2281
- if (ne2 == 1) {
2338
+ static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
2339
+ if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
2282
2340
  if (ggml_is_quantized(src0->type)) {
2283
- ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2341
+ if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
2342
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2343
+ return;
2344
+ }
2284
2345
  } else {
2285
- ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2346
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
2347
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2348
+ return;
2349
+ }
2286
2350
  }
2287
- return;
2288
2351
  }
2289
2352
 
2290
2353
  if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
@@ -2298,6 +2361,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2298
2361
  }
2299
2362
  }
2300
2363
 
2364
+ // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
2365
+ // TODO: add asserts to verify this. should work with CUDA, HIP, etc.
2301
2366
  cudaStream_t stream = ctx.stream();
2302
2367
 
2303
2368
  GGML_ASSERT(nb12 % nb11 == 0);
@@ -2723,6 +2788,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2723
2788
  case GGML_OP_GATED_LINEAR_ATTN:
2724
2789
  ggml_cuda_op_gated_linear_attn(ctx, dst);
2725
2790
  break;
2791
+ case GGML_OP_GATED_DELTA_NET:
2792
+ ggml_cuda_op_gated_delta_net(ctx, dst);
2793
+ break;
2726
2794
  case GGML_OP_RWKV_WKV7:
2727
2795
  ggml_cuda_op_rwkv_wkv7(ctx, dst);
2728
2796
  break;
@@ -2858,14 +2926,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2858
2926
  bool use_cuda_graph = true;
2859
2927
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2860
2928
 
2861
- const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2862
- const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2863
- const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2864
- const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2865
- const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2866
- const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2867
- const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2868
-
2869
2929
  for (int i = 0; i < cgraph->n_nodes; i++) {
2870
2930
  ggml_tensor * node = cgraph->nodes[i];
2871
2931
 
@@ -2880,30 +2940,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2880
2940
  #endif
2881
2941
  }
2882
2942
 
2883
- if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2884
- use_cuda_graph = false; // This node type is not supported by CUDA graph capture
2885
- #ifndef NDEBUG
2886
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2887
- #endif
2888
- }
2889
-
2890
- if (node->op == GGML_OP_ADD &&
2891
- node->src[1] && node->src[1]->ne[1] > 1 &&
2892
- (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2893
- (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2894
- strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2895
- strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2896
- strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2897
- strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2898
- strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
2899
- // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2900
- // by means of matching node names. See
2901
- // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2902
- // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2903
- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2943
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2944
+ if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
2945
+ // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2946
+ // TODO: figure out a way to enable for larger batch sizes, without hurting performance
2947
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18958
2904
2948
  use_cuda_graph = false;
2905
2949
  #ifndef NDEBUG
2906
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2950
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2907
2951
  #endif
2908
2952
  }
2909
2953
 
@@ -2916,21 +2960,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2916
2960
  }
2917
2961
 
2918
2962
  static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919
- props->node_address = node->data;
2963
+ memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
2964
+ props->node_data = node->data;
2920
2965
  props->node_op = node->op;
2966
+ props->node_type = node->type;
2967
+ props->flags = node->flags;
2921
2968
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2922
2969
  props->ne[i] = node->ne[i];
2923
2970
  props->nb[i] = node->nb[i];
2924
2971
  }
2925
2972
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2926
- props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2973
+ if (!node->src[i]) {
2974
+ continue;
2975
+ }
2976
+
2977
+ props->src_data[i] = node->src[i]->data;
2927
2978
  }
2928
2979
  memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2929
2980
  }
2930
2981
 
2931
2982
  static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2932
- if (node->data != props->node_address &&
2933
- node->op != GGML_OP_VIEW) {
2983
+ if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
2934
2984
  return false;
2935
2985
  }
2936
2986
 
@@ -2938,6 +2988,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
2938
2988
  return false;
2939
2989
  }
2940
2990
 
2991
+ if (node->type != props->node_type) {
2992
+ return false;
2993
+ }
2994
+
2941
2995
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2942
2996
  if (node->ne[i] != props->ne[i]) {
2943
2997
  return false;
@@ -2947,73 +3001,104 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
2947
3001
  }
2948
3002
  }
2949
3003
 
2950
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2951
- if (node->src[i] &&
2952
- node->src[i]->data != props->src_address[i] &&
2953
- node->op != GGML_OP_VIEW
2954
- ) {
2955
- return false;
3004
+ if (node->op != GGML_OP_VIEW) {
3005
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
3006
+ if (!node->src[i]) {
3007
+ if (props->src_data[i] != nullptr) {
3008
+ return false;
3009
+ }
3010
+ continue;
3011
+ }
3012
+
3013
+ if (node->src[i]->data != props->src_data[i]) {
3014
+ return false;
3015
+ }
2956
3016
  }
2957
3017
  }
2958
3018
 
2959
- if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960
- memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3019
+ if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3020
+ return false;
3021
+ }
3022
+
3023
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
2961
3024
  return false;
2962
3025
  }
2963
3026
 
2964
3027
  return true;
2965
3028
  }
2966
3029
 
2967
- static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3030
+ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
3031
+ return cgraph->nodes[0];
3032
+ }
2968
3033
 
3034
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2969
3035
  bool res = false;
2970
3036
 
2971
- if (cuda_ctx->cuda_graph->instance == nullptr) {
2972
- res = true;
2973
- }
3037
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3038
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
2974
3039
 
2975
3040
  // Check if the graph size has changed
2976
- if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
3041
+ if (graph->props.size() != (size_t)cgraph->n_nodes) {
2977
3042
  res = true;
2978
- cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
3043
+ graph->props.resize(cgraph->n_nodes);
2979
3044
  }
2980
3045
 
2981
3046
  // Loop over nodes in GGML graph to determine if CUDA graph update is required
2982
3047
  // and store properties to allow this comparison for the next token
3048
+ std::unordered_set<ggml_tensor *> seen_node;
3049
+ std::vector<ggml_tensor *> srcs_extra;
2983
3050
  for (int i = 0; i < cgraph->n_nodes; i++) {
2984
3051
  bool props_match = true;
3052
+
3053
+ seen_node.insert(cgraph->nodes[i]);
3054
+
2985
3055
  if (!res) {
2986
- props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
3056
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
2987
3057
  }
2988
3058
  if (!props_match) {
2989
3059
  res = true;
2990
3060
  }
2991
- ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
3061
+ ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
3062
+
3063
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3064
+ ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
3065
+ if (src && seen_node.find(src) == seen_node.end()) {
3066
+ srcs_extra.push_back(src);
3067
+ }
3068
+ }
2992
3069
  }
2993
3070
 
2994
- for (int i = 0; i < cgraph->n_leafs; i++) {
2995
- bool props_match= true;
3071
+ if (graph->extra.size() != (size_t) srcs_extra.size()) {
3072
+ res = true;
3073
+ graph->extra.resize(srcs_extra.size());
3074
+ }
3075
+
3076
+ for (size_t i = 0; i < srcs_extra.size(); ++i) {
3077
+ bool props_match = true;
3078
+
2996
3079
  if (!res) {
2997
- props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
3080
+ props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
2998
3081
  }
3082
+
2999
3083
  if (!props_match) {
3000
3084
  res = true;
3001
3085
  }
3002
- ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
3086
+ ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
3003
3087
  }
3004
3088
 
3005
3089
  return res;
3006
3090
  }
3007
3091
 
3008
- static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
3092
+ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3093
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3009
3094
 
3010
3095
  #if CUDART_VERSION >= 12000
3011
3096
  cudaGraphExecUpdateResultInfo result_info;
3012
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
3097
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
3013
3098
  #else
3014
3099
  cudaGraphNode_t errorNode;
3015
3100
  cudaGraphExecUpdateResult result_info;
3016
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
3101
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
3017
3102
  #endif // CUDART_VERSION >= 12000
3018
3103
 
3019
3104
  if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3024,14 +3109,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
3024
3109
  // The pre-existing graph exec cannot be updated due to violated constraints
3025
3110
  // so instead clear error and re-instantiate
3026
3111
  (void)cudaGetLastError();
3027
- CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
3028
- cuda_ctx->cuda_graph->instance = nullptr;
3029
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
3112
+ CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
3113
+ graph->instance = nullptr;
3114
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3030
3115
  } else {
3031
3116
  GGML_ASSERT(stat == cudaSuccess);
3032
3117
  }
3033
3118
  }
3034
- #endif
3119
+ #endif // USE_CUDA_GRAPH
3035
3120
 
3036
3121
  static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3037
3122
  const ggml_tensor * view,
@@ -3067,63 +3152,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3067
3152
  return true;
3068
3153
  }
3069
3154
 
3070
- static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
3071
- #ifndef NDEBUG
3072
- const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3073
- GGML_ASSERT(unary_ops.size() == num_unary);
3074
- #endif
3155
+ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
3156
+ args.sigmoid = false;
3157
+ args.softmax = false;
3158
+ args.delayed_softmax = false;
3159
+ args.prob_bias = false;
3160
+ args.norm = false;
3075
3161
 
3076
- //TODO: remove special case once ggml_can_fuse can handle empty nodes
3077
- std::initializer_list<enum ggml_op> topk_moe_ops =
3078
- ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
3079
- std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
3080
- ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
3081
- std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
3082
- ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
3162
+ const int n_nodes = cgraph->n_nodes;
3163
+ ggml_tensor ** nodes = cgraph->nodes;
3083
3164
 
3084
- const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3085
- const std::initializer_list<enum ggml_op> & list2) {
3086
- return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3087
- };
3088
-
3089
- if (is_equal(topk_moe_ops_with_norm, ops) &&
3090
- ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
3091
- ggml_tensor * softmax = cgraph->nodes[node_idx];
3092
- ggml_tensor * weights = cgraph->nodes[node_idx + 9];
3093
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3094
- ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3095
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3165
+ if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
3166
+ args.softmax = true;
3167
+ }
3096
3168
 
3097
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3098
- return true;
3169
+ if (nodes[node_idx]->op == GGML_OP_UNARY) {
3170
+ if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
3171
+ return false;
3099
3172
  }
3173
+ args.sigmoid = true;
3100
3174
  }
3101
3175
 
3102
- if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
3103
- ggml_tensor * softmax = cgraph->nodes[node_idx];
3104
- ggml_tensor * weights = cgraph->nodes[node_idx + 4];
3105
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3106
- ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3107
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3176
+ if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
3177
+ args.delayed_softmax = true;
3178
+ }
3108
3179
 
3109
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3110
- return true;
3180
+ node_idx++;
3181
+
3182
+ if (args.sigmoid || args.softmax) {
3183
+ // SOFTMAX -> RESHAPE
3184
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
3185
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3186
+ return false;
3187
+ }
3188
+ ggml_tensor * probs_reshaped = nodes[node_idx];
3189
+ node_idx++;
3190
+
3191
+ if (node_idx >= n_nodes) {
3192
+ return false;
3193
+ }
3194
+
3195
+ // src of bias add is the unreshaped probs (-2 instead of -1)
3196
+ if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
3197
+ args.prob_bias = true;
3198
+ node_idx++;
3199
+ }
3200
+ // RESHAPE/ADD -> ARGSORT
3201
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
3202
+ return false;
3203
+ }
3204
+
3205
+ if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3206
+ return false;
3207
+ } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
3208
+ return false;
3209
+ }
3210
+
3211
+ node_idx++;
3212
+
3213
+ // ARGSORT-> VIEW
3214
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3215
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3216
+ return false;
3217
+ }
3218
+ node_idx++;
3219
+
3220
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
3221
+ return false;
3222
+ }
3223
+
3224
+ // GET_ROWS
3225
+ if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
3226
+ return false;
3227
+ }
3228
+ node_idx++;
3229
+ } else if (args.delayed_softmax) {
3230
+ if (node_idx - 2 < 0) {
3231
+ return false;
3111
3232
  }
3233
+ ggml_tensor * probs_reshaped = nodes[node_idx - 2];
3234
+
3235
+ // VIEW->ARGSORT
3236
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3237
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3238
+ return false;
3239
+ }
3240
+ node_idx++;
3241
+
3242
+ // GET_ROWS
3243
+ if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3244
+ nodes[node_idx]->src[0] != probs_reshaped) {
3245
+ return false;
3246
+ }
3247
+ node_idx++;
3248
+
3249
+ static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
3250
+
3251
+ for (const ggml_op op : remaining_ops) {
3252
+ if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3253
+ return false;
3254
+ }
3255
+ node_idx++;
3256
+ }
3257
+ }
3258
+
3259
+ // At this point we can check for norm + scale. Everything is now at least valid till the norm
3260
+ if (node_idx >= n_nodes) {
3261
+ return true;
3112
3262
  }
3113
3263
 
3114
- if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
3115
- ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
3116
- ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
3117
- ggml_tensor * weights = cgraph->nodes[node_idx + 5];
3118
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
3119
- ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
3120
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3264
+ if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
3265
+ //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
3266
+ static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
3121
3267
 
3122
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3268
+ args.norm = true;
3269
+ for (const ggml_op op : norm_ops) {
3270
+ if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3271
+ node_idx++;
3272
+ } else {
3273
+ args.norm = false;
3274
+ return true;
3275
+ }
3276
+ }
3277
+
3278
+ // DIV <- CLAMP, RESHAPE
3279
+ if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3280
+ nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
3281
+ args.norm = false;
3123
3282
  return true;
3124
3283
  }
3284
+ node_idx++;
3285
+
3286
+ if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3287
+ args.norm = false;
3288
+ return true;
3289
+ }
3290
+
3291
+ node_idx++;
3292
+ }
3293
+
3294
+ if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3295
+ args.scale = true;
3125
3296
  }
3126
3297
 
3298
+ return true;
3299
+ }
3300
+
3301
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3302
+ int node_idx,
3303
+ std::initializer_list<enum ggml_op> ops,
3304
+ std::initializer_list<enum ggml_unary_op> unary_ops) {
3305
+ #ifndef NDEBUG
3306
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3307
+ GGML_ASSERT(unary_ops.size() == num_unary);
3308
+ #endif
3309
+
3310
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3311
+ const std::initializer_list<enum ggml_op> & list2) {
3312
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3313
+ };
3314
+
3127
3315
  std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
3128
3316
  std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
3129
3317
 
@@ -3200,7 +3388,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3200
3388
  return false;
3201
3389
  }
3202
3390
 
3203
- //rms_norm kernel assumes contigous rows
3391
+ //rms_norm kernel assumes contiguous rows
3204
3392
  if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
3205
3393
  return false;
3206
3394
  }
@@ -3212,6 +3400,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3212
3400
  return true;
3213
3401
  }
3214
3402
 
3403
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
3404
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3405
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3406
+ const ggml_tensor * silu = cgraph->nodes[node_idx+1];
3407
+
3408
+ if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3409
+ return false;
3410
+ }
3411
+
3412
+ return true;
3413
+ }
3414
+
3415
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
3416
+ && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
3417
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
3418
+ const ggml_tensor * mul = cgraph->nodes[node_idx+1];
3419
+
3420
+ if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
3421
+ return false;
3422
+ }
3423
+
3424
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
3425
+ return false;
3426
+ }
3427
+
3428
+ if (unary->type != mul->type) {
3429
+ return false;
3430
+ }
3431
+
3432
+ const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
3433
+ if (other->type != unary->type) {
3434
+ return false;
3435
+ }
3436
+ if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
3437
+ return false;
3438
+ }
3439
+
3440
+ return true;
3441
+ }
3442
+
3215
3443
  if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
3216
3444
  && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
3217
3445
  const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -3236,7 +3464,70 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3236
3464
  return false;
3237
3465
  }
3238
3466
 
3239
- static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3467
+ // returns whether the write (out) nodes overwrite the read nodes in operation
3468
+ static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
3469
+ int node_idx,
3470
+ int node_count,
3471
+ int * out_nodes,
3472
+ int out_count) {
3473
+ auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3474
+ const int64_t a_start = (int64_t) a->data;
3475
+ const int64_t a_end = a_start + ggml_nbytes(a);
3476
+
3477
+ const int64_t b_start = (int64_t) b->data;
3478
+ const int64_t b_end = b_start + ggml_nbytes(b);
3479
+
3480
+ if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3481
+ return true;
3482
+ }
3483
+
3484
+ return false;
3485
+ };
3486
+
3487
+ bool is_ok = true;
3488
+ // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
3489
+ if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
3490
+ return true;
3491
+ }
3492
+
3493
+ for (int i = 0; i < out_count; ++i) {
3494
+ const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
3495
+
3496
+ for (int j = node_idx; j < node_idx + node_count; ++j) {
3497
+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
3498
+ // the destination and the src is not an intermediate node that's being
3499
+ // elided, then disable fusion.
3500
+
3501
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3502
+ const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
3503
+
3504
+ if (!src || src->op == GGML_OP_NONE) {
3505
+ continue;
3506
+ }
3507
+
3508
+ if (nodes_overlap(dst, src)) {
3509
+ bool found = false;
3510
+
3511
+ for (int k = node_idx; k < j; ++k) {
3512
+ if (cgraph->nodes[k] == src) {
3513
+ found = true;
3514
+ break;
3515
+ }
3516
+ }
3517
+
3518
+ if (!found) {
3519
+ is_ok = false;
3520
+ break;
3521
+ }
3522
+ }
3523
+ }
3524
+ }
3525
+ }
3526
+
3527
+ return is_ok;
3528
+ }
3529
+
3530
+ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
3240
3531
  bool graph_evaluated_or_captured = false;
3241
3532
 
3242
3533
  // flag used to determine whether it is an integrated_gpu
@@ -3378,39 +3669,84 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3378
3669
  continue;
3379
3670
  }
3380
3671
 
3672
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
3673
+ continue;
3674
+ }
3381
3675
 
3382
3676
  // start of fusion operations
3383
3677
  static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
3384
3678
  if (!disable_fusion) {
3679
+ ggml_cuda_topk_moe_args args;
3680
+
3681
+ if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3682
+ cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3683
+ const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3684
+
3685
+ std::vector<ggml_op> ops;
3686
+
3687
+ if (can_fuse) {
3688
+ const ggml_tensor * logits = node->src[0];
3689
+ ggml_tensor * weights = nullptr;
3690
+ ggml_tensor * ids = nullptr;
3691
+ const ggml_tensor * bias = nullptr;
3692
+ const ggml_tensor * clamp = nullptr;
3693
+ const ggml_tensor * scale = nullptr;
3694
+
3695
+ if (!args.delayed_softmax) {
3696
+ ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3697
+ int out_nodes[2]; // nodes which can't be elided
3698
+
3699
+ if (args.prob_bias) {
3700
+ bias = cgraph->nodes[i + 2]->src[1];
3701
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
3702
+ GGML_OP_VIEW, GGML_OP_GET_ROWS });
3703
+ out_nodes[0] = i + 4;
3704
+ ids = cgraph->nodes[i + 4];
3705
+ } else {
3706
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
3707
+ GGML_OP_GET_ROWS });
3708
+ out_nodes[0] = i + 3;
3709
+ ids = cgraph->nodes[i + 3];
3710
+ }
3385
3711
 
3386
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
3387
- ggml_tensor * weights = cgraph->nodes[i + 9];
3388
- ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3389
- ggml_tensor * clamp = cgraph->nodes[i + 7];
3390
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
3391
- /*delayed softmax*/ false, clamp);
3392
- i += 9;
3393
- continue;
3394
- }
3395
-
3396
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
3397
- ggml_tensor * weights = cgraph->nodes[i + 4];
3398
- ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3399
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
3400
- /*delayed softmax*/ false);
3401
- i += 4;
3402
- continue;
3403
- }
3712
+ if (args.norm) {
3713
+ ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
3714
+ GGML_OP_DIV, GGML_OP_RESHAPE });
3715
+ clamp = cgraph->nodes[i + ops.size() - 3];
3716
+ }
3717
+ if (args.scale) {
3718
+ ops.insert(ops.end(), { GGML_OP_SCALE });
3719
+ scale = cgraph->nodes[i + ops.size() - 1];
3720
+ }
3404
3721
 
3405
- if (ggml_cuda_can_fuse(cgraph, i,
3406
- ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
3407
- ggml_tensor * weights = cgraph->nodes[i + 5];
3408
- ggml_tensor * ids = cgraph->nodes[i + 1];
3722
+ weights = cgraph->nodes[i + ops.size() - 1];
3723
+ out_nodes[1] = i + ops.size() - 1;
3409
3724
 
3410
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
3411
- /*delayed_softmax*/ true);
3412
- i += 5;
3413
- continue;
3725
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3726
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3727
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3728
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3729
+ i += ops.size() - 1;
3730
+ continue;
3731
+ }
3732
+ } else if (!args.norm && !args.prob_bias) {
3733
+ //special case gpt-oss, no norm, no bias.
3734
+ ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
3735
+ GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3736
+ weights = cgraph->nodes[i + 5];
3737
+ ids = cgraph->nodes[i + 1];
3738
+ const ggml_tensor * softmax = cgraph->nodes[i + 4];
3739
+
3740
+ int out_nodes[2] = { i + 1, i + 5 };
3741
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3742
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3743
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3744
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3745
+ i += ops.size() - 1;
3746
+ continue;
3747
+ }
3748
+ }
3749
+ }
3414
3750
  }
3415
3751
 
3416
3752
  if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
@@ -3442,11 +3778,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3442
3778
  n_fuse++;
3443
3779
 
3444
3780
  if (n_fuse > 1) {
3781
+ ggml_tensor fused_add_node;
3782
+ memcpy(&fused_add_node, node, sizeof(ggml_tensor));
3445
3783
  for (int j = 0; j < n_fuse - 1; ++j) {
3446
- node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3784
+ fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3447
3785
  }
3448
- cgraph->nodes[i + n_fuse - 1]->data = node->data;
3449
- ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
3786
+ fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
3787
+ ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
3450
3788
  i += n_fuse - 1;
3451
3789
 
3452
3790
  continue;
@@ -3655,6 +3993,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3655
3993
  continue;
3656
3994
  }
3657
3995
 
3996
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
3997
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
3998
+ i++;
3999
+ continue;
4000
+ }
4001
+
4002
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
4003
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
4004
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
4005
+ ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
4006
+ i++;
4007
+ continue;
4008
+ }
4009
+
3658
4010
  if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
3659
4011
  i += 2;
3660
4012
  ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
@@ -3687,13 +4039,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3687
4039
  }
3688
4040
 
3689
4041
  #ifdef USE_CUDA_GRAPH
4042
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3690
4043
  if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
3691
- if (cuda_ctx->cuda_graph->graph != nullptr) {
3692
- CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
3693
- cuda_ctx->cuda_graph->graph = nullptr;
4044
+ if (graph->graph != nullptr) {
4045
+ CUDA_CHECK(cudaGraphDestroy(graph->graph));
4046
+ graph->graph = nullptr;
3694
4047
  }
3695
4048
 
3696
- CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
4049
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
3697
4050
  graph_evaluated_or_captured = true; // CUDA graph has been captured
3698
4051
 
3699
4052
  std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3706,41 +4059,38 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3706
4059
  }
3707
4060
 
3708
4061
  if (use_cuda_graph) {
3709
- if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
3710
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
4062
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4063
+ if (graph->instance == nullptr) { // Create executable graph from captured graph.
4064
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3711
4065
  }
3712
4066
  if (cuda_graph_update_required) { // Update graph executable
3713
- ggml_cuda_graph_update_executable(cuda_ctx);
4067
+ ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
3714
4068
  }
3715
4069
  // Launch graph
3716
- CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
4070
+ CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
3717
4071
  #else
4072
+ GGML_UNUSED(graph_key);
3718
4073
  graph_evaluated_or_captured = true;
3719
4074
  #endif // USE_CUDA_GRAPH
3720
4075
  }
3721
4076
  }
3722
4077
 
3723
- static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
3724
-
3725
4078
  #ifdef USE_CUDA_GRAPH
4079
+ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
4080
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3726
4081
 
3727
- if (cuda_ctx->cuda_graph == nullptr) {
3728
- cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
3729
- }
3730
-
3731
- if (cuda_ctx->cuda_graph->graph == nullptr) {
4082
+ if (graph->graph == nullptr) {
3732
4083
  if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3733
- cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3734
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
4084
+ if (!graph->disable_due_to_gpu_arch) {
4085
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
4086
+ }
4087
+ graph->disable_due_to_gpu_arch = true;
3735
4088
  }
3736
4089
  }
3737
4090
 
3738
- return cuda_ctx->cuda_graph->is_enabled();
3739
- #else
3740
- GGML_UNUSED(cuda_ctx);
3741
- return false;
3742
- #endif // USE_CUDA_GRAPH
4091
+ return graph->is_enabled();
3743
4092
  }
4093
+ #endif // USE_CUDA_GRAPH
3744
4094
 
3745
4095
  static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3746
4096
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@@ -3749,15 +4099,40 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3749
4099
 
3750
4100
  bool use_cuda_graph = false;
3751
4101
  bool cuda_graph_update_required = false;
4102
+ const void * graph_key = nullptr;
3752
4103
 
3753
4104
  #ifdef USE_CUDA_GRAPH
3754
- use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3755
-
3756
- if (cuda_ctx->cuda_graph->is_enabled()) {
3757
- cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3758
- use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
3759
-
3760
- cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
4105
+ graph_key = ggml_cuda_graph_get_key(cgraph);
4106
+
4107
+ ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4108
+
4109
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4110
+ if (graph->is_enabled()) {
4111
+ const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
4112
+ if (graph_compatible) {
4113
+ const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
4114
+
4115
+ if (!graph->warmup_complete) {
4116
+ // Warmup: need at least 2 calls with no property change on the 2nd call
4117
+ if (!properties_changed) {
4118
+ graph->warmup_complete = true;
4119
+ GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
4120
+ use_cuda_graph = true;
4121
+ cuda_graph_update_required = true;
4122
+ }
4123
+ // else: properties changed or first call - execute directly (use_cuda_graph stays false)
4124
+ } else {
4125
+ // Post-warmup: normal CUDA graph operation
4126
+ if (properties_changed) {
4127
+ // Properties changed - reset warmup, execute directly until stable again
4128
+ graph->warmup_complete = false;
4129
+ GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
4130
+ } else {
4131
+ use_cuda_graph = true;
4132
+ cuda_graph_update_required = graph->instance == nullptr;
4133
+ }
4134
+ }
4135
+ }
3761
4136
  }
3762
4137
  #endif // USE_CUDA_GRAPH
3763
4138
 
@@ -3771,7 +4146,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3771
4146
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3772
4147
  }
3773
4148
 
3774
- ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
4149
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
3775
4150
 
3776
4151
  return GGML_STATUS_SUCCESS;
3777
4152
  }
@@ -3804,7 +4179,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
3804
4179
  static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
3805
4180
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3806
4181
 
3807
- const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
4182
+ #ifdef USE_CUDA_GRAPH
4183
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
4184
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4185
+ #else
4186
+ const bool use_cuda_graph = false;
4187
+ GGML_UNUSED(cuda_ctx);
4188
+ GGML_UNUSED(cgraph);
4189
+ #endif
3808
4190
 
3809
4191
  static bool enable_graph_optimization = [] {
3810
4192
  const char * env = getenv("GGML_CUDA_GRAPH_OPT");
@@ -4335,6 +4717,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4335
4717
  case GGML_UNARY_OP_CEIL:
4336
4718
  case GGML_UNARY_OP_ROUND:
4337
4719
  case GGML_UNARY_OP_TRUNC:
4720
+ // TODO: should become:
4721
+ //return ggml_is_contiguous_rows(op->src[0]);
4338
4722
  return ggml_is_contiguous(op->src[0]);
4339
4723
  default:
4340
4724
  return false;
@@ -4551,7 +4935,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4551
4935
  case GGML_OP_L2_NORM:
4552
4936
  return true;
4553
4937
  case GGML_OP_RMS_NORM_BACK:
4554
- return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
4938
+ return ggml_is_contiguous(op->src[0]);
4555
4939
  break;
4556
4940
  case GGML_OP_NONE:
4557
4941
  case GGML_OP_RESHAPE:
@@ -4613,8 +4997,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4613
4997
  case GGML_OP_CONV_2D_DW:
4614
4998
  case GGML_OP_CONV_TRANSPOSE_2D:
4615
4999
  case GGML_OP_POOL_2D:
4616
- case GGML_OP_ACC:
4617
5000
  return true;
5001
+ case GGML_OP_ACC:
5002
+ // TODO: extend support like so:
5003
+ //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
5004
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
4618
5005
  case GGML_OP_SUM:
4619
5006
  return ggml_is_contiguous_rows(op->src[0]);
4620
5007
  case GGML_OP_TOP_K:
@@ -4627,8 +5014,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4627
5014
  case GGML_OP_SUM_ROWS:
4628
5015
  case GGML_OP_MEAN:
4629
5016
  case GGML_OP_GROUP_NORM:
4630
- case GGML_OP_PAD:
4631
5017
  return ggml_is_contiguous(op->src[0]);
5018
+ case GGML_OP_PAD:
5019
+ return true;
4632
5020
  case GGML_OP_UPSCALE:
4633
5021
  case GGML_OP_PAD_REFLECT_1D:
4634
5022
  case GGML_OP_ARANGE:
@@ -4638,6 +5026,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4638
5026
  case GGML_OP_GATED_LINEAR_ATTN:
4639
5027
  case GGML_OP_RWKV_WKV7:
4640
5028
  return true;
5029
+ case GGML_OP_GATED_DELTA_NET:
5030
+ //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
5031
+ #ifdef GGML_USE_MUSA
5032
+ return false;
5033
+ #else
5034
+ return true;
5035
+ #endif // GGML_USE_MUSA
4641
5036
  case GGML_OP_FLASH_ATTN_EXT:
4642
5037
  return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
4643
5038
  case GGML_OP_CROSS_ENTROPY_LOSS: