whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -4,6 +4,7 @@
4
4
 
5
5
  #include "ggml-cuda/common.cuh"
6
6
  #include "ggml-cuda/acc.cuh"
7
+ #include "ggml-cuda/add-id.cuh"
7
8
  #include "ggml-cuda/arange.cuh"
8
9
  #include "ggml-cuda/argmax.cuh"
9
10
  #include "ggml-cuda/argsort.cuh"
@@ -11,38 +12,54 @@
11
12
  #include "ggml-cuda/clamp.cuh"
12
13
  #include "ggml-cuda/concat.cuh"
13
14
  #include "ggml-cuda/conv-transpose-1d.cuh"
15
+ #include "ggml-cuda/conv2d.cuh"
14
16
  #include "ggml-cuda/conv2d-dw.cuh"
15
17
  #include "ggml-cuda/conv2d-transpose.cuh"
16
18
  #include "ggml-cuda/convert.cuh"
17
19
  #include "ggml-cuda/count-equal.cuh"
18
20
  #include "ggml-cuda/cpy.cuh"
19
21
  #include "ggml-cuda/cross-entropy-loss.cuh"
22
+ #include "ggml-cuda/cumsum.cuh"
20
23
  #include "ggml-cuda/diagmask.cuh"
24
+ #include "ggml-cuda/diag.cuh"
21
25
  #include "ggml-cuda/fattn.cuh"
22
26
  #include "ggml-cuda/getrows.cuh"
23
27
  #include "ggml-cuda/im2col.cuh"
28
+ #include "ggml-cuda/mmf.cuh"
24
29
  #include "ggml-cuda/mmq.cuh"
25
- #include "ggml-cuda/mmv.cuh"
30
+ #include "ggml-cuda/mmvf.cuh"
26
31
  #include "ggml-cuda/mmvq.cuh"
27
32
  #include "ggml-cuda/norm.cuh"
28
33
  #include "ggml-cuda/opt-step-adamw.cuh"
34
+ #include "ggml-cuda/opt-step-sgd.cuh"
29
35
  #include "ggml-cuda/out-prod.cuh"
30
36
  #include "ggml-cuda/pad.cuh"
31
37
  #include "ggml-cuda/pool2d.cuh"
32
38
  #include "ggml-cuda/quantize.cuh"
33
39
  #include "ggml-cuda/rope.cuh"
40
+ #include "ggml-cuda/roll.cuh"
34
41
  #include "ggml-cuda/scale.cuh"
42
+ #include "ggml-cuda/softcap.cuh"
35
43
  #include "ggml-cuda/softmax.cuh"
36
44
  #include "ggml-cuda/ssm-conv.cuh"
37
45
  #include "ggml-cuda/ssm-scan.cuh"
38
46
  #include "ggml-cuda/sum.cuh"
39
47
  #include "ggml-cuda/sumrows.cuh"
48
+ #include "ggml-cuda/top-k.cuh"
40
49
  #include "ggml-cuda/mean.cuh"
41
50
  #include "ggml-cuda/tsembd.cuh"
51
+ #include "ggml-cuda/topk-moe.cuh"
42
52
  #include "ggml-cuda/unary.cuh"
43
53
  #include "ggml-cuda/upscale.cuh"
44
54
  #include "ggml-cuda/wkv.cuh"
45
55
  #include "ggml-cuda/gla.cuh"
56
+ #include "ggml-cuda/set.cuh"
57
+ #include "ggml-cuda/set-rows.cuh"
58
+ #include "ggml-cuda/pad_reflect_1d.cuh"
59
+ #include "ggml-cuda/solve_tri.cuh"
60
+ #include "ggml-cuda/tri.cuh"
61
+ #include "ggml-cuda/cumsum.cuh"
62
+ #include "ggml-cuda/fill.cuh"
46
63
  #include "ggml.h"
47
64
 
48
65
  #include <algorithm>
@@ -54,6 +71,7 @@
54
71
  #include <cstddef>
55
72
  #include <cstdint>
56
73
  #include <float.h>
74
+ #include <initializer_list>
57
75
  #include <limits>
58
76
  #include <map>
59
77
  #include <memory>
@@ -124,7 +142,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
124
142
  return err;
125
143
  }
126
144
 
127
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
145
+ #if defined(GGML_USE_HIP)
128
146
  static int ggml_cuda_parse_id(char devName[]) {
129
147
  // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
130
148
  // these values are not stable so this is susceptible to breakage
@@ -171,33 +189,9 @@ static int ggml_cuda_parse_id(char devName[]) {
171
189
  archNum += archMinor;
172
190
  return archNum;
173
191
  }
174
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
192
+ #endif // defined(GGML_USE_HIP)
175
193
 
176
194
  static ggml_cuda_device_info ggml_cuda_init() {
177
- #ifdef __HIP_PLATFORM_AMD__
178
- // Workaround for a rocBLAS bug when using multiple graphics cards:
179
- // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
180
- {
181
- int major_version = 0;
182
- size_t version_length = 0;
183
- if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
184
- std::vector<char> version(version_length+1, '\0');
185
- if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
186
- version.resize(::strlen(version.data()));
187
- int parsed_value = 0;
188
- if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
189
- major_version = parsed_value;
190
- }
191
- }
192
- }
193
- if (major_version < 4) {
194
- GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
195
- rocblas_initialize();
196
- CUDA_CHECK(cudaDeviceSynchronize());
197
- }
198
- }
199
- #endif
200
-
201
195
  ggml_cuda_device_info info = {};
202
196
 
203
197
  cudaError_t err = cudaGetDeviceCount(&info.device_count);
@@ -209,17 +203,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
209
203
  GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
210
204
 
211
205
  int64_t total_vram = 0;
212
- #ifdef GGML_CUDA_FORCE_MMQ
213
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
214
- #else
215
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
216
- #endif // GGML_CUDA_FORCE_MMQ
217
- #ifdef GGML_CUDA_FORCE_CUBLAS
218
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
219
- #else
220
- GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
221
- #endif // GGML_CUDA_FORCE_CUBLAS
222
206
  GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
207
+
208
+ std::vector<std::pair<int, std::string>> turing_devices_without_mma;
223
209
  for (int id = 0; id < info.device_count; ++id) {
224
210
  int device_vmm = 0;
225
211
 
@@ -243,11 +229,19 @@ static ggml_cuda_device_info ggml_cuda_init() {
243
229
 
244
230
  info.default_tensor_split[id] = total_vram;
245
231
  total_vram += prop.totalGlobalMem;
246
- info.devices[id].integrated = prop.integrated;
232
+ info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
247
233
  info.devices[id].nsm = prop.multiProcessorCount;
248
234
  info.devices[id].smpb = prop.sharedMemPerBlock;
249
235
  info.devices[id].warp_size = prop.warpSize;
250
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
236
+
237
+ #ifndef GGML_USE_MUSA
238
+ int supports_coop_launch = 0;
239
+ CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
240
+ info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
241
+ #else
242
+ info.devices[id].supports_cooperative_launch = false;
243
+ #endif // !(GGML_USE_MUSA)
244
+ #if defined(GGML_USE_HIP)
251
245
  info.devices[id].smpbo = prop.sharedMemPerBlock;
252
246
 
253
247
  info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
@@ -277,7 +271,34 @@ static ggml_cuda_device_info ggml_cuda_init() {
277
271
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
278
272
  GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
279
273
  id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
280
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
274
+ std::string device_name(prop.name);
275
+ if (device_name == "NVIDIA GeForce MX450") {
276
+ turing_devices_without_mma.push_back({ id, device_name });
277
+ } else if (device_name == "NVIDIA GeForce MX550") {
278
+ turing_devices_without_mma.push_back({ id, device_name });
279
+ } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
280
+ turing_devices_without_mma.push_back({ id, device_name });
281
+ }
282
+
283
+ // Temporary performance fix:
284
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
285
+ // TODO: Check for future drivers the default scheduling strategy and
286
+ // remove this call again when cudaDeviceScheduleSpin is default.
287
+ if (prop.major == 12 && prop.minor == 1) {
288
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
289
+ }
290
+
291
+ #endif // defined(GGML_USE_HIP)
292
+ }
293
+
294
+ if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
295
+ GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
296
+ for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
297
+ GGML_LOG_INFO(
298
+ " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
299
+ }
300
+ GGML_LOG_INFO(
301
+ "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
281
302
  }
282
303
 
283
304
  for (int id = 0; id < info.device_count; ++id) {
@@ -505,7 +526,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
505
526
  };
506
527
  #endif // defined(GGML_USE_VMM)
507
528
 
508
- std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
529
+ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
530
+ [[maybe_unused]] int stream_no) {
509
531
  #if defined(GGML_USE_VMM)
510
532
  if (ggml_cuda_info().devices[device].vmm) {
511
533
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@@ -1345,9 +1367,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1345
1367
  &beta, dst_dd_i, ldc));
1346
1368
  }
1347
1369
 
1348
- GGML_UNUSED(dst);
1349
- GGML_UNUSED(src1_ddq_i);
1350
- GGML_UNUSED(src1_padded_row_size);
1370
+ GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
1351
1371
  }
1352
1372
 
1353
1373
  static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
@@ -1848,6 +1868,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1848
1868
  ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
1849
1869
  ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
1850
1870
 
1871
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
1872
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
1873
+
1851
1874
  // Handle src0
1852
1875
  src0_ptr = (const cuda_t *) src0->data;
1853
1876
 
@@ -1866,6 +1889,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1866
1889
  s11 = ne10;
1867
1890
  s12 = ne11*s11;
1868
1891
  s13 = ne12*s12;
1892
+
1893
+ is_src1_cont_2 = true;
1869
1894
  }
1870
1895
 
1871
1896
  // Setup destination buffer
@@ -1914,15 +1939,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1914
1939
  const int64_t r2 = ne12/ne02;
1915
1940
  const int64_t r3 = ne13/ne03;
1916
1941
 
1917
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1942
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1943
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1944
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1945
+ const int64_t smb = ne12 == 1 ? s13 : s12;
1946
+
1918
1947
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
1919
1948
  // use cublasGemmStridedBatchedEx
1920
1949
  CUBLAS_CHECK(
1921
1950
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1922
1951
  ne01, ne11, ne10,
1923
- alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1924
- src1_ptr, cu_data_type_b, s11, s12, // strideB
1925
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1952
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1953
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
1954
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1926
1955
  ne12*ne13,
1927
1956
  cu_compute_type,
1928
1957
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1935,8 +1964,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1935
1964
 
1936
1965
  size_t src1_stride_size = sizeof(cuda_t);
1937
1966
 
1938
- dim3 block_dims(ne13, ne12);
1939
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
1967
+ const int threads_x = 16;
1968
+ const int threads_y = 16;
1969
+ dim3 block_dims(threads_x, threads_y);
1970
+
1971
+ dim3 grid_dims(
1972
+ (ne13 + threads_x - 1) / threads_x,
1973
+ (ne12 + threads_y - 1) / threads_y
1974
+ );
1975
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
1940
1976
  src0_ptr, src1_ptr, dst_t,
1941
1977
  ptrs_src.get(), ptrs_dst.get(),
1942
1978
  ne12, ne13,
@@ -1985,6 +2021,164 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1985
2021
  }
1986
2022
  }
1987
2023
 
2024
+ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
2025
+ const ggml_tensor * ffn_gate,
2026
+ const ggml_tensor * glu,
2027
+ const ggml_tensor * ffn_up_bias = nullptr,
2028
+ const ggml_tensor * ffn_gate_bias = nullptr) {
2029
+ const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
2030
+
2031
+ if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
2032
+ return false;
2033
+ }
2034
+
2035
+ const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
2036
+ const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
2037
+
2038
+ GGML_ASSERT(ffn_up && ffn_gate && glu);
2039
+
2040
+ if (!is_mul_mat && !is_mul_mat_id) {
2041
+ return false;
2042
+ }
2043
+
2044
+ const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
2045
+
2046
+ if (has_bias) {
2047
+ if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
2048
+ return false;
2049
+ }
2050
+
2051
+ if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
2052
+ return false;
2053
+ }
2054
+
2055
+ if (expected_bias_op == GGML_OP_ADD) {
2056
+ const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
2057
+ const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
2058
+ if (!up_has_mul || !gate_has_mul) {
2059
+ return false;
2060
+ }
2061
+ } else { // GGML_OP_ADD_ID
2062
+ if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
2063
+ return false;
2064
+ }
2065
+ if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
2066
+ return false;
2067
+ }
2068
+ }
2069
+ } else {
2070
+ if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
2071
+ return false;
2072
+ }
2073
+ }
2074
+
2075
+ if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
2076
+ !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
2077
+ return false;
2078
+ }
2079
+
2080
+ if (ffn_up->src[1] != ffn_gate->src[1]) {
2081
+ return false;
2082
+ }
2083
+
2084
+ if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
2085
+ return false;
2086
+ }
2087
+
2088
+ static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
2089
+
2090
+ if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
2091
+ return false;
2092
+ }
2093
+
2094
+ if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
2095
+ return false;
2096
+ }
2097
+
2098
+ const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
2099
+ ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
2100
+
2101
+ //TODO: add support for fusion for split buffers
2102
+ if (split) {
2103
+ return false;
2104
+ }
2105
+
2106
+ return true;
2107
+ }
2108
+
2109
+ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
2110
+ ggml_tensor * src0 = tensor->src[0];
2111
+ ggml_tensor * src1 = tensor->src[1];
2112
+ const ggml_tensor * dst = tensor;
2113
+
2114
+ const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
2115
+
2116
+ bool use_mul_mat_vec_f =
2117
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
2118
+ src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2119
+
2120
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2121
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
2122
+
2123
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2124
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2125
+
2126
+ //TODO: add support for fusion for split buffers
2127
+ if (split) {
2128
+ return false;
2129
+ }
2130
+
2131
+ //we only support fusion for ncols_dst = 1
2132
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2133
+ return false;
2134
+ }
2135
+
2136
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2137
+ return false;
2138
+ }
2139
+
2140
+
2141
+ return use_mul_mat_vec_f;
2142
+ }
2143
+
2144
+ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
2145
+ ggml_tensor * src0 = tensor->src[0];
2146
+ ggml_tensor * src1 = tensor->src[1];
2147
+ const ggml_tensor * dst = tensor;
2148
+
2149
+ const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
2150
+ ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
2151
+ src0->view_src;
2152
+
2153
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
2154
+ dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
2155
+
2156
+ // fusion is not universally faster on Pascal
2157
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2158
+ if (cc <= GGML_CUDA_CC_PASCAL) {
2159
+ return false;
2160
+ }
2161
+ //we only support fusion for ncols_dst = 1
2162
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
2163
+ return false;
2164
+ }
2165
+
2166
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
2167
+ return false;
2168
+ }
2169
+
2170
+
2171
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
2172
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
2173
+
2174
+ //TODO: add support for fusion for split buffers
2175
+ if (split) {
2176
+ return false;
2177
+ }
2178
+
2179
+ return use_mul_mat_vec_q;
2180
+ }
2181
+
1988
2182
  static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1989
2183
  const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
1990
2184
 
@@ -1994,7 +2188,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1994
2188
  const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
1995
2189
  && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
1996
2190
 
1997
- bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
2191
+ bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
2192
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2193
+ bool use_mul_mat_f = !ggml_is_quantized(src0->type)
1998
2194
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
1999
2195
  bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
2000
2196
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2014,14 +2210,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2014
2210
  }
2015
2211
 
2016
2212
  const int cc = ggml_cuda_info().devices[id].cc;
2017
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2018
- use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
2213
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
2214
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2215
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2216
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2019
2217
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2020
2218
  }
2021
2219
  } else {
2022
2220
  const int cc = ggml_cuda_info().devices[ctx.device].cc;
2023
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
2024
- use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
2221
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
2222
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2223
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2224
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2025
2225
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2026
2226
  }
2027
2227
 
@@ -2034,15 +2234,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2034
2234
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
2035
2235
 
2036
2236
  //TODO update for generic tensor parallelism
2037
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2237
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2038
2238
  bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
2039
2239
  bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2040
2240
  bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2041
2241
 
2042
- if (!split && use_mul_mat_vec) {
2242
+ if (!split && use_mul_mat_vec_f) {
2043
2243
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
2044
2244
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
2045
- ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
2245
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
2246
+ } else if (!split && use_mul_mat_f) {
2247
+ ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
2046
2248
  } else if (!split && use_mul_mat_vec_q) {
2047
2249
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
2048
2250
  } else if (!split && use_mul_mat_q) {
@@ -2051,8 +2253,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2051
2253
  && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
2052
2254
  // general KQ + KQV multi-batch without FlashAttention
2053
2255
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
2054
- } else if (use_mul_mat_vec) {
2055
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
2256
+ } else if (use_mul_mat_vec_f) {
2257
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
2056
2258
  } else if (use_mul_mat_vec_q) {
2057
2259
  ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
2058
2260
  } else if (use_mul_mat_q) {
@@ -2080,15 +2282,20 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2080
2282
  if (ggml_is_quantized(src0->type)) {
2081
2283
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2082
2284
  } else {
2083
- ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
2285
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2084
2286
  }
2085
2287
  return;
2086
2288
  }
2087
2289
 
2088
- if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
2290
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
2089
2291
  ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
2090
2292
  return;
2091
2293
  }
2294
+
2295
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
2296
+ ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
2297
+ return;
2298
+ }
2092
2299
  }
2093
2300
 
2094
2301
  cudaStream_t stream = ctx.stream();
@@ -2230,6 +2437,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2230
2437
  case GGML_OP_GET_ROWS_BACK:
2231
2438
  ggml_cuda_op_get_rows_back(ctx, dst);
2232
2439
  break;
2440
+ case GGML_OP_SET_ROWS:
2441
+ ggml_cuda_op_set_rows(ctx, dst);
2442
+ break;
2443
+ case GGML_OP_SET:
2444
+ ggml_cuda_op_set(ctx, dst);
2445
+ break;
2233
2446
  case GGML_OP_DUP:
2234
2447
  ggml_cuda_dup(ctx, dst);
2235
2448
  break;
@@ -2243,6 +2456,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2243
2456
  case GGML_OP_ADD1: // TODO: more efficient implementation
2244
2457
  ggml_cuda_op_add(ctx, dst);
2245
2458
  break;
2459
+ case GGML_OP_ADD_ID:
2460
+ ggml_cuda_op_add_id(ctx, dst);
2461
+ break;
2246
2462
  case GGML_OP_SUB:
2247
2463
  ggml_cuda_op_sub(ctx, dst);
2248
2464
  break;
@@ -2299,6 +2515,30 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2299
2515
  case GGML_UNARY_OP_EXP:
2300
2516
  ggml_cuda_op_exp(ctx, dst);
2301
2517
  break;
2518
+ case GGML_UNARY_OP_ELU:
2519
+ ggml_cuda_op_elu(ctx, dst);
2520
+ break;
2521
+ case GGML_UNARY_OP_XIELU:
2522
+ ggml_cuda_op_xielu(ctx, dst);
2523
+ break;
2524
+ case GGML_UNARY_OP_FLOOR:
2525
+ ggml_cuda_op_floor(ctx, dst);
2526
+ break;
2527
+ case GGML_UNARY_OP_CEIL:
2528
+ ggml_cuda_op_ceil(ctx, dst);
2529
+ break;
2530
+ case GGML_UNARY_OP_ROUND:
2531
+ ggml_cuda_op_round(ctx, dst);
2532
+ break;
2533
+ case GGML_UNARY_OP_TRUNC:
2534
+ ggml_cuda_op_trunc(ctx, dst);
2535
+ break;
2536
+ case GGML_UNARY_OP_EXPM1:
2537
+ ggml_cuda_op_expm1(ctx, dst);
2538
+ break;
2539
+ case GGML_UNARY_OP_SOFTPLUS:
2540
+ ggml_cuda_op_softplus(ctx, dst);
2541
+ break;
2302
2542
  default:
2303
2543
  return false;
2304
2544
  }
@@ -2314,6 +2554,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2314
2554
  case GGML_GLU_OP_SWIGLU:
2315
2555
  ggml_cuda_op_swiglu(ctx, dst);
2316
2556
  break;
2557
+ case GGML_GLU_OP_SWIGLU_OAI:
2558
+ ggml_cuda_op_swiglu_oai(ctx, dst);
2559
+ break;
2560
+ case GGML_GLU_OP_GEGLU_ERF:
2561
+ ggml_cuda_op_geglu_erf(ctx, dst);
2562
+ break;
2563
+ case GGML_GLU_OP_GEGLU_QUICK:
2564
+ ggml_cuda_op_geglu_quick(ctx, dst);
2565
+ break;
2317
2566
  default:
2318
2567
  return false;
2319
2568
  }
@@ -2336,6 +2585,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2336
2585
  case GGML_OP_PAD:
2337
2586
  ggml_cuda_op_pad(ctx, dst);
2338
2587
  break;
2588
+ case GGML_OP_PAD_REFLECT_1D:
2589
+ ggml_cuda_op_pad_reflect_1d(ctx, dst);
2590
+ break;
2339
2591
  case GGML_OP_ARANGE:
2340
2592
  ggml_cuda_op_arange(ctx, dst);
2341
2593
  break;
@@ -2390,6 +2642,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2390
2642
  case GGML_OP_PERMUTE:
2391
2643
  case GGML_OP_TRANSPOSE:
2392
2644
  break;
2645
+ case GGML_OP_DIAG:
2646
+ ggml_cuda_op_diag(ctx, dst);
2647
+ break;
2393
2648
  case GGML_OP_DIAG_MASK_INF:
2394
2649
  ggml_cuda_op_diag_mask_inf(ctx, dst);
2395
2650
  break;
@@ -2405,9 +2660,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2405
2660
  case GGML_OP_ROPE_BACK:
2406
2661
  ggml_cuda_op_rope_back(ctx, dst);
2407
2662
  break;
2663
+ case GGML_OP_ROLL:
2664
+ ggml_cuda_op_roll(ctx, dst);
2665
+ break;
2408
2666
  case GGML_OP_IM2COL:
2409
2667
  ggml_cuda_op_im2col(ctx, dst);
2410
2668
  break;
2669
+ case GGML_OP_IM2COL_3D:
2670
+ ggml_cuda_op_im2col_3d(ctx, dst);
2671
+ break;
2672
+ case GGML_OP_CONV_2D:
2673
+ ggml_cuda_op_conv2d(ctx, dst);
2674
+ break;
2411
2675
  case GGML_OP_CONV_2D_DW:
2412
2676
  ggml_cuda_op_conv2d_dw(ctx, dst);
2413
2677
  break;
@@ -2423,6 +2687,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2423
2687
  case GGML_OP_SUM:
2424
2688
  ggml_cuda_op_sum(ctx, dst);
2425
2689
  break;
2690
+ case GGML_OP_CUMSUM:
2691
+ ggml_cuda_op_cumsum(ctx, dst);
2692
+ break;
2426
2693
  case GGML_OP_SUM_ROWS:
2427
2694
  ggml_cuda_op_sum_rows(ctx, dst);
2428
2695
  break;
@@ -2435,6 +2702,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2435
2702
  case GGML_OP_SSM_SCAN:
2436
2703
  ggml_cuda_op_ssm_scan(ctx, dst);
2437
2704
  break;
2705
+ case GGML_OP_TOP_K:
2706
+ ggml_cuda_op_top_k(ctx, dst);
2707
+ break;
2438
2708
  case GGML_OP_ARGSORT:
2439
2709
  ggml_cuda_op_argsort(ctx, dst);
2440
2710
  break;
@@ -2444,6 +2714,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2444
2714
  case GGML_OP_CROSS_ENTROPY_LOSS:
2445
2715
  ggml_cuda_cross_entropy_loss(ctx, dst);
2446
2716
  break;
2717
+ case GGML_OP_TRI:
2718
+ ggml_cuda_op_tri(ctx, dst);
2719
+ break;
2447
2720
  case GGML_OP_RWKV_WKV6:
2448
2721
  ggml_cuda_op_rwkv_wkv6(ctx, dst);
2449
2722
  break;
@@ -2459,6 +2732,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2459
2732
  case GGML_OP_OPT_STEP_ADAMW:
2460
2733
  ggml_cuda_opt_step_adamw(ctx, dst);
2461
2734
  break;
2735
+ case GGML_OP_OPT_STEP_SGD:
2736
+ ggml_cuda_opt_step_sgd(ctx, dst);
2737
+ break;
2738
+ case GGML_OP_SOLVE_TRI:
2739
+ ggml_cuda_op_solve_tri(ctx, dst);
2740
+ break;
2741
+ case GGML_OP_FILL:
2742
+ ggml_cuda_op_fill(ctx, dst);
2743
+ break;
2462
2744
  default:
2463
2745
  return false;
2464
2746
  }
@@ -2571,11 +2853,18 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
2571
2853
  }
2572
2854
 
2573
2855
  #ifdef USE_CUDA_GRAPH
2574
- static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2575
- bool use_cuda_graph) {
2856
+ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2576
2857
 
2858
+ bool use_cuda_graph = true;
2577
2859
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2578
- cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
2860
+
2861
+ const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2862
+ const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2863
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2864
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2865
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2866
+ const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2867
+ const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2579
2868
 
2580
2869
  for (int i = 0; i < cgraph->n_nodes; i++) {
2581
2870
  ggml_tensor * node = cgraph->nodes[i];
@@ -2598,127 +2887,125 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2598
2887
  #endif
2599
2888
  }
2600
2889
 
2601
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
2602
- // disable CUDA graphs for batch size > 1 for now.
2603
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
2890
+ if (node->op == GGML_OP_ADD &&
2891
+ node->src[1] && node->src[1]->ne[1] > 1 &&
2892
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2893
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2894
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2895
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2896
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2897
+ strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2898
+ strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
2899
+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2900
+ // by means of matching node names. See
2901
+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2902
+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2903
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2604
2904
  use_cuda_graph = false;
2605
2905
  #ifndef NDEBUG
2606
2906
  GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2607
2907
  #endif
2608
2908
  }
2609
2909
 
2610
- if (node->op == GGML_OP_CPY) {
2611
-
2612
- // Store the pointers which are updated for each token, such that these can be sent
2613
- // to the device and accessed using indirection from CUDA graph
2614
- cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
2615
-
2616
- // store a pointer to each copy op CUDA kernel to identify it later
2617
- void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2618
- if (!ptr) {
2619
- use_cuda_graph = false;
2620
- #ifndef NDEBUG
2621
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
2622
- #endif
2623
- }
2624
- }
2625
-
2626
2910
  if (!use_cuda_graph) {
2627
2911
  break;
2628
2912
  }
2629
2913
  }
2630
2914
 
2631
- if (use_cuda_graph) {
2632
- cuda_ctx->cuda_graph->use_cpy_indirection = true;
2633
- // copy pointers to GPU so they can be accessed via indirection within CUDA graph
2634
- ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
2635
- }
2636
-
2637
2915
  return use_cuda_graph;
2638
2916
  }
2639
2917
 
2640
- static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2641
- graph_node_properties->node_address = node->data;
2642
- graph_node_properties->node_op = node->op;
2918
+ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919
+ props->node_address = node->data;
2920
+ props->node_op = node->op;
2643
2921
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2644
- graph_node_properties->ne[i] = node->ne[i];
2645
- graph_node_properties->nb[i] = node->nb[i];
2922
+ props->ne[i] = node->ne[i];
2923
+ props->nb[i] = node->nb[i];
2646
2924
  }
2647
2925
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2648
- graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2926
+ props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2649
2927
  }
2650
- memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2928
+ memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2651
2929
  }
2652
2930
 
2653
- static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2654
- if (node->data != graph_node_properties->node_address &&
2655
- node->op != GGML_OP_CPY &&
2931
+ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2932
+ if (node->data != props->node_address &&
2656
2933
  node->op != GGML_OP_VIEW) {
2657
2934
  return false;
2658
2935
  }
2659
2936
 
2660
- if (node->op != graph_node_properties->node_op) {
2937
+ if (node->op != props->node_op) {
2661
2938
  return false;
2662
2939
  }
2663
2940
 
2664
2941
  for (int i = 0; i < GGML_MAX_DIMS; i++) {
2665
- if (node->ne[i] != graph_node_properties->ne[i]) {
2942
+ if (node->ne[i] != props->ne[i]) {
2666
2943
  return false;
2667
2944
  }
2668
- if (node->nb[i] != graph_node_properties->nb[i]) {
2945
+ if (node->nb[i] != props->nb[i]) {
2669
2946
  return false;
2670
2947
  }
2671
2948
  }
2672
2949
 
2673
2950
  for (int i = 0; i < GGML_MAX_SRC; i++) {
2674
2951
  if (node->src[i] &&
2675
- node->src[i]->data != graph_node_properties->src_address[i] &&
2676
- node->op != GGML_OP_CPY &&
2952
+ node->src[i]->data != props->src_address[i] &&
2677
2953
  node->op != GGML_OP_VIEW
2678
2954
  ) {
2679
2955
  return false;
2680
2956
  }
2681
2957
  }
2682
2958
 
2683
- if (node->op == GGML_OP_SCALE &&
2684
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2959
+ if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960
+ memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2685
2961
  return false;
2686
2962
  }
2687
2963
 
2688
2964
  return true;
2689
2965
  }
2690
2966
 
2691
- static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2967
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2692
2968
 
2693
- bool cuda_graph_update_required = false;
2969
+ bool res = false;
2694
2970
 
2695
2971
  if (cuda_ctx->cuda_graph->instance == nullptr) {
2696
- cuda_graph_update_required = true;
2972
+ res = true;
2697
2973
  }
2698
2974
 
2699
2975
  // Check if the graph size has changed
2700
- if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2701
- cuda_graph_update_required = true;
2702
- cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2976
+ if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
2977
+ res = true;
2978
+ cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
2703
2979
  }
2704
2980
 
2705
2981
  // Loop over nodes in GGML graph to determine if CUDA graph update is required
2706
2982
  // and store properties to allow this comparison for the next token
2707
2983
  for (int i = 0; i < cgraph->n_nodes; i++) {
2708
- bool has_matching_properties = true;
2709
- if (!cuda_graph_update_required) {
2710
- has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2984
+ bool props_match = true;
2985
+ if (!res) {
2986
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
2987
+ }
2988
+ if (!props_match) {
2989
+ res = true;
2711
2990
  }
2712
- if (!has_matching_properties) {
2713
- cuda_graph_update_required = true;
2991
+ ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
2992
+ }
2993
+
2994
+ for (int i = 0; i < cgraph->n_leafs; i++) {
2995
+ bool props_match= true;
2996
+ if (!res) {
2997
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
2714
2998
  }
2715
- set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2999
+ if (!props_match) {
3000
+ res = true;
3001
+ }
3002
+ ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
2716
3003
  }
2717
3004
 
2718
- return cuda_graph_update_required;
3005
+ return res;
2719
3006
  }
2720
3007
 
2721
- static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
3008
+ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
2722
3009
 
2723
3010
  #if CUDART_VERSION >= 12000
2724
3011
  cudaGraphExecUpdateResultInfo result_info;
@@ -2746,154 +3033,745 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2746
3033
  }
2747
3034
  #endif
2748
3035
 
2749
- static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2750
- bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
2751
- // flag used to determine whether it is an integrated_gpu
2752
- const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
2753
-
2754
- while (!graph_evaluated_or_captured) {
2755
- // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2756
- // With the use of CUDA graphs, the execution will be performed by the graph launch.
2757
- if (!use_cuda_graph || cuda_graph_update_required) {
2758
- for (int i = 0; i < cgraph->n_nodes; i++) {
2759
- ggml_tensor * node = cgraph->nodes[i];
2760
-
2761
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2762
- continue;
2763
- }
2764
-
2765
- #ifndef NDEBUG
2766
- assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
2767
- for (int j = 0; j < GGML_MAX_SRC; j++) {
2768
- if (node->src[j] != nullptr) {
2769
- assert(node->src[j]->buffer);
2770
- assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
2771
- ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
2772
- }
2773
- }
2774
- #else
2775
- GGML_UNUSED(integrated);
2776
- #endif // NDEBUG
3036
+ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3037
+ const ggml_tensor * view,
3038
+ const ggml_tensor * set_rows) {
2777
3039
 
2778
- bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
2779
- if (!ok) {
2780
- GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2781
- }
2782
- GGML_ASSERT(ok);
2783
- }
2784
- }
3040
+ if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
3041
+ return false;
3042
+ }
3043
+ // ne3 not tested
3044
+ if (rope->src[0]->ne[3] != 1) {
3045
+ return false;
3046
+ }
2785
3047
 
2786
- #ifdef USE_CUDA_GRAPH
2787
- if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
2788
- if (cuda_ctx->cuda_graph->graph != nullptr) {
2789
- CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
2790
- cuda_ctx->cuda_graph->graph = nullptr;
2791
- }
3048
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
3049
+ return false;
3050
+ }
2792
3051
 
2793
- CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
2794
- graph_evaluated_or_captured = true; // CUDA graph has been captured
3052
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
3053
+ return false;
3054
+ }
2795
3055
 
2796
- std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2797
- if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
2798
- ggml_cuda_lock_cv.notify_all();
2799
- }
2800
- } else {
2801
- graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
2802
- }
3056
+ // The view should flatten two dims of rope into one dim
3057
+ if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
3058
+ return false;
2803
3059
  }
2804
3060
 
2805
- if (use_cuda_graph) {
2806
- if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
2807
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
2808
- }
2809
- if (cuda_graph_update_required) { // Update graph executable
2810
- update_cuda_graph_executable(cuda_ctx);
2811
- }
2812
- // Launch graph
2813
- CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
2814
- #else
2815
- graph_evaluated_or_captured = true;
2816
- #endif // USE_CUDA_GRAPH
3061
+ // Only norm/neox shaders have the fusion code
3062
+ const int mode = ((const int32_t *) rope->op_params)[2];
3063
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
3064
+ return false;
2817
3065
  }
3066
+
3067
+ return true;
2818
3068
  }
2819
3069
 
2820
- static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2821
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3070
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
3071
+ #ifndef NDEBUG
3072
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3073
+ GGML_ASSERT(unary_ops.size() == num_unary);
3074
+ #endif
2822
3075
 
2823
- ggml_cuda_set_device(cuda_ctx->device);
3076
+ //TODO: remove special case once ggml_can_fuse can handle empty nodes
3077
+ std::initializer_list<enum ggml_op> topk_moe_ops =
3078
+ ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
3079
+ std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
3080
+ ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
3081
+ std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
3082
+ ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
3083
+
3084
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3085
+ const std::initializer_list<enum ggml_op> & list2) {
3086
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3087
+ };
2824
3088
 
2825
- #ifdef USE_CUDA_GRAPH
2826
- static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
3089
+ if (is_equal(topk_moe_ops_with_norm, ops) &&
3090
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
3091
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
3092
+ ggml_tensor * weights = cgraph->nodes[node_idx + 9];
3093
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3094
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3095
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
2827
3096
 
2828
- // Objects required for CUDA Graph
2829
- if (cuda_ctx->cuda_graph == nullptr) {
2830
- cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
3097
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3098
+ return true;
3099
+ }
2831
3100
  }
2832
3101
 
2833
- bool use_cuda_graph = true;
2834
- bool cuda_graph_update_required = false;
3102
+ if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
3103
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
3104
+ ggml_tensor * weights = cgraph->nodes[node_idx + 4];
3105
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3106
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3107
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
2835
3108
 
2836
- if (cuda_ctx->cuda_graph->graph == nullptr) {
2837
- if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
2838
- cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
2839
- #ifndef NDEBUG
2840
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
2841
- #endif
3109
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3110
+ return true;
2842
3111
  }
2843
3112
  }
2844
3113
 
2845
- // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
2846
- // or previous graph capture failure.
2847
- // Also disable for multi-gpu for now. TO DO investigate
2848
- if (disable_cuda_graphs_due_to_env
2849
- || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
2850
- || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
2851
- || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
2852
- use_cuda_graph = false;
3114
+ if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
3115
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
3116
+ ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
3117
+ ggml_tensor * weights = cgraph->nodes[node_idx + 5];
3118
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
3119
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
3120
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3121
+
3122
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3123
+ return true;
3124
+ }
2853
3125
  }
2854
3126
 
2855
- if (use_cuda_graph) {
2856
- cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
3127
+ std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
3128
+ std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
2857
3129
 
2858
- use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
3130
+ std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
3131
+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
2859
3132
 
2860
- // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
2861
- if (use_cuda_graph && cuda_graph_update_required) {
2862
- cuda_ctx->cuda_graph->number_consecutive_updates++;
2863
- } else {
2864
- cuda_ctx->cuda_graph->number_consecutive_updates = 0;
3133
+ if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
3134
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
3135
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3136
+ const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
3137
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
3138
+ const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
3139
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
3140
+
3141
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3142
+ return true;
2865
3143
  }
3144
+ }
2866
3145
 
2867
- if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
2868
- cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
2869
- #ifndef NDEBUG
2870
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
2871
- #endif
3146
+ if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
3147
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3148
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
3149
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
3150
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
3151
+
3152
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3153
+ return true;
2872
3154
  }
2873
3155
  }
2874
3156
 
2875
- if (use_cuda_graph && cuda_graph_update_required) {
2876
- // Start CUDA graph capture
2877
- {
3157
+ std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
3158
+
3159
+ if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3160
+ const ggml_tensor * rope = cgraph->nodes[node_idx];
3161
+ const ggml_tensor * view = cgraph->nodes[node_idx + 1];
3162
+ const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
3163
+
3164
+ if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
3165
+ return true;
3166
+ }
3167
+ }
3168
+
3169
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
3170
+ return false;
3171
+ }
3172
+
3173
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
3174
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
3175
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
3176
+ const ggml_tensor *add = nullptr;
3177
+
3178
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
3179
+ add = cgraph->nodes[node_idx+2];
3180
+ }
3181
+
3182
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
3183
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
3184
+
3185
+ //rms norm only supports F32
3186
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
3187
+ mul->src[1]->type != GGML_TYPE_F32 ||
3188
+ mul->type != GGML_TYPE_F32) {
3189
+ return false;
3190
+ }
3191
+
3192
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
3193
+ add->src[1]->type != GGML_TYPE_F32 ||
3194
+ add->type != GGML_TYPE_F32) ) {
3195
+ return false;
3196
+ }
3197
+
3198
+ //if rms norm is the B operand, then we don't handle broadcast
3199
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
3200
+ return false;
3201
+ }
3202
+
3203
+ //rms_norm kernel assumes contigous rows
3204
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
3205
+ return false;
3206
+ }
3207
+
3208
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
3209
+ return false;
3210
+ }
3211
+
3212
+ return true;
3213
+ }
3214
+
3215
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
3216
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
3217
+ const ggml_tensor *scale = cgraph->nodes[node_idx];
3218
+ const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
3219
+ const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
3220
+
3221
+ GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
3222
+ GGML_ASSERT(scale->type == GGML_TYPE_F32);
3223
+
3224
+ if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
3225
+ return false;
3226
+ }
3227
+
3228
+ // Check for bias
3229
+ if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
3230
+ return false;
3231
+ }
3232
+
3233
+ return true;
3234
+ }
3235
+
3236
+ return false;
3237
+ }
3238
+
3239
+ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3240
+ bool graph_evaluated_or_captured = false;
3241
+
3242
+ // flag used to determine whether it is an integrated_gpu
3243
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3244
+
3245
+ ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
3246
+ bool is_concurrent_event_active = false;
3247
+ ggml_cuda_concurrent_event * concurrent_event = nullptr;
3248
+ bool should_launch_concurrent_events = false;
3249
+
3250
+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3251
+ if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3252
+ concurrent_event = &stream_ctx.concurrent_events[node];
3253
+
3254
+ is_concurrent_event_active = true;
3255
+
3256
+ GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3257
+
3258
+ cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
3259
+ GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3260
+ CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3261
+
3262
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3263
+ cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3264
+ CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3265
+ }
3266
+ }
3267
+ };
3268
+
3269
+ while (!graph_evaluated_or_captured) {
3270
+ // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
3271
+ // With the use of CUDA graphs, the execution will be performed by the graph launch.
3272
+ if (!use_cuda_graph || cuda_graph_update_required) {
3273
+ [[maybe_unused]] int prev_i = 0;
3274
+
3275
+ if (stream_ctx.concurrent_events.size() > 0) {
3276
+ should_launch_concurrent_events = true;
3277
+ for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
3278
+ should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
3279
+ }
3280
+ }
3281
+
3282
+ if (should_launch_concurrent_events) {
3283
+ // Restore original node order within each concurrent region to enable fusion within streams
3284
+
3285
+ std::unordered_map<const ggml_tensor *, int> node_to_idx;
3286
+ node_to_idx.reserve(cgraph->n_nodes);
3287
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
3288
+ node_to_idx[cgraph->nodes[i]] = i;
3289
+ }
3290
+
3291
+ for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
3292
+ // Find positions of all nodes from this event in the current graph
3293
+ std::vector<int> positions;
3294
+ positions.reserve(event.original_order.size());
3295
+
3296
+ bool all_found = true;
3297
+ for (const ggml_tensor * orig_node : event.original_order) {
3298
+ auto it = node_to_idx.find(orig_node);
3299
+ if (it != node_to_idx.end()) {
3300
+ positions.push_back(it->second);
3301
+ } else {
3302
+ all_found = false;
3303
+ break;
3304
+ }
3305
+ }
3306
+
3307
+ if (!all_found || positions.size() != event.original_order.size()) {
3308
+ continue;
3309
+ }
3310
+
3311
+ // Sort positions to get contiguous range
3312
+ std::vector<int> sorted_positions = positions;
3313
+ std::sort(sorted_positions.begin(), sorted_positions.end());
3314
+
3315
+ bool is_contiguous = true;
3316
+ for (size_t i = 1; i < sorted_positions.size(); ++i) {
3317
+ if (sorted_positions[i] != sorted_positions[i-1] + 1) {
3318
+ is_contiguous = false;
3319
+ break;
3320
+ }
3321
+ }
3322
+
3323
+ if (!is_contiguous) {
3324
+ continue;
3325
+ }
3326
+
3327
+ // Restore original order at the sorted positions
3328
+ int start_pos = sorted_positions[0];
3329
+ for (size_t i = 0; i < event.original_order.size(); ++i) {
3330
+ cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
3331
+ }
3332
+ }
3333
+ } else {
3334
+ stream_ctx.concurrent_events.clear();
3335
+ }
3336
+
3337
+ for (int i = 0; i < cgraph->n_nodes; i++) {
3338
+ ggml_tensor * node = cgraph->nodes[i];
3339
+ if (is_concurrent_event_active) {
3340
+ GGML_ASSERT(concurrent_event);
3341
+
3342
+ if (node == concurrent_event->join_node) {
3343
+ cuda_ctx->curr_stream_no = 0;
3344
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3345
+ // Wait on join events of forked streams in the main stream
3346
+ CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
3347
+ cuda_ctx->stream(cuda_ctx->device, i)));
3348
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
3349
+ }
3350
+
3351
+ is_concurrent_event_active = false;
3352
+ concurrent_event = nullptr;
3353
+ } else {
3354
+ GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
3355
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3356
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3357
+ }
3358
+ } else if (i - prev_i > 1) {
3359
+ //the previous node was fused
3360
+ const ggml_tensor * prev_node = cgraph->nodes[i - 1];
3361
+ try_launch_concurrent_event(prev_node);
3362
+
3363
+ if (is_concurrent_event_active) {
3364
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3365
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3366
+ }
3367
+ }
3368
+
3369
+ #ifdef GGML_CUDA_DEBUG
3370
+ const int nodes_fused = i - prev_i - 1;
3371
+ if (nodes_fused > 0) {
3372
+ GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
3373
+ }
3374
+ #endif
3375
+ prev_i = i;
3376
+
3377
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
3378
+ continue;
3379
+ }
3380
+
3381
+
3382
+ // start of fusion operations
3383
+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
3384
+ if (!disable_fusion) {
3385
+
3386
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
3387
+ ggml_tensor * weights = cgraph->nodes[i + 9];
3388
+ ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3389
+ ggml_tensor * clamp = cgraph->nodes[i + 7];
3390
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
3391
+ /*delayed softmax*/ false, clamp);
3392
+ i += 9;
3393
+ continue;
3394
+ }
3395
+
3396
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
3397
+ ggml_tensor * weights = cgraph->nodes[i + 4];
3398
+ ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3399
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
3400
+ /*delayed softmax*/ false);
3401
+ i += 4;
3402
+ continue;
3403
+ }
3404
+
3405
+ if (ggml_cuda_can_fuse(cgraph, i,
3406
+ ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
3407
+ ggml_tensor * weights = cgraph->nodes[i + 5];
3408
+ ggml_tensor * ids = cgraph->nodes[i + 1];
3409
+
3410
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
3411
+ /*delayed_softmax*/ true);
3412
+ i += 5;
3413
+ continue;
3414
+ }
3415
+
3416
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3417
+ ggml_tensor * rope = cgraph->nodes[i];
3418
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
3419
+
3420
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3421
+ i += 2;
3422
+ continue;
3423
+ }
3424
+
3425
+ if (node->op == GGML_OP_ADD) {
3426
+ int n_fuse = 0;
3427
+ ggml_op ops[8];
3428
+ std::fill(ops, ops + 8, GGML_OP_ADD);
3429
+
3430
+ for (; n_fuse <= 6; ++n_fuse){
3431
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
3432
+ break;
3433
+ }
3434
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
3435
+ break;
3436
+ }
3437
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
3438
+ break;
3439
+ }
3440
+ }
3441
+
3442
+ n_fuse++;
3443
+
3444
+ if (n_fuse > 1) {
3445
+ for (int j = 0; j < n_fuse - 1; ++j) {
3446
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3447
+ }
3448
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
3449
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
3450
+ i += n_fuse - 1;
3451
+
3452
+ continue;
3453
+ }
3454
+ }
3455
+
3456
+ bool fused_mul_mat_vec = false;
3457
+ int fused_node_count = 0;
3458
+
3459
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3460
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3461
+
3462
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3463
+ ggml_tensor * glu = cgraph->nodes[i + 4];
3464
+ ggml_tensor * gate_bias_n = glu->src[0];
3465
+ ggml_tensor * up_bias_n = glu->src[1];
3466
+
3467
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3468
+ ggml_tensor * gate_n = nullptr;
3469
+ ggml_tensor * up_n = nullptr;
3470
+
3471
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3472
+ gate_n = cgraph->nodes[i];
3473
+ up_n = cgraph->nodes[i + 2];
3474
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3475
+ gate_n = cgraph->nodes[i + 2];
3476
+ up_n = cgraph->nodes[i];
3477
+ } else {
3478
+ continue;
3479
+ }
3480
+
3481
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3482
+ if (op_bias == GGML_OP_ADD) {
3483
+ if (bias_node->src[0] == mul_node) {
3484
+ return bias_node->src[1];
3485
+ }
3486
+ if (bias_node->src[1] == mul_node) {
3487
+ return bias_node->src[0];
3488
+ }
3489
+ return (ggml_tensor *) nullptr;
3490
+ }
3491
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3492
+ GGML_ASSERT(bias_node->src[0] == mul_node);
3493
+ return bias_node->src[1];
3494
+ };
3495
+
3496
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
3497
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3498
+
3499
+ if (!up_bias_tensor || !gate_bias_tensor) {
3500
+ continue;
3501
+ }
3502
+
3503
+ // we don't support repeating adds
3504
+ if (bias_op == GGML_OP_ADD &&
3505
+ (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3506
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3507
+ continue;
3508
+ }
3509
+
3510
+ const ggml_tensor * src0 = up_n->src[0];
3511
+ const ggml_tensor * src1 = up_n->src[1];
3512
+ const ggml_tensor * ids = up_n->src[2];
3513
+
3514
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3515
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3516
+ fusion_data.gate = gate_n->src[0];
3517
+ fusion_data.x_bias = up_bias_tensor;
3518
+ fusion_data.gate_bias = gate_bias_tensor;
3519
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3520
+
3521
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3522
+ fused_mul_mat_vec = true;
3523
+ fused_node_count = 5;
3524
+ break;
3525
+ }
3526
+
3527
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3528
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3529
+ fusion_data.gate = gate_n->src[0];
3530
+ fusion_data.x_bias = up_bias_tensor;
3531
+ fusion_data.gate_bias = gate_bias_tensor;
3532
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3533
+
3534
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3535
+ fused_mul_mat_vec = true;
3536
+ fused_node_count = 5;
3537
+ break;
3538
+ }
3539
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3540
+ ggml_tensor * glu = cgraph->nodes[i + 2];
3541
+ ggml_tensor * gate = glu->src[0];
3542
+ ggml_tensor * up = glu->src[1];
3543
+
3544
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3545
+ || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3546
+
3547
+ if (!ok) continue;
3548
+
3549
+ const ggml_tensor * src0 = up->src[0];
3550
+ const ggml_tensor * src1 = up->src[1];
3551
+ const ggml_tensor * ids = up->src[2];
3552
+
3553
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3554
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3555
+ fusion_data.gate = gate->src[0];
3556
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3557
+
3558
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3559
+ fused_mul_mat_vec = true;
3560
+ fused_node_count = 3;
3561
+ break;
3562
+ }
3563
+
3564
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3565
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3566
+ fusion_data.gate = gate->src[0];
3567
+ fusion_data.glu_op = ggml_get_glu_op(glu);
3568
+
3569
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3570
+ fused_mul_mat_vec = true;
3571
+ fused_node_count = 3;
3572
+ break;
3573
+ }
3574
+ }
3575
+ }
3576
+
3577
+ if (fused_mul_mat_vec) {
3578
+ i += fused_node_count - 1;
3579
+ continue;
3580
+ }
3581
+
3582
+ fused_mul_mat_vec = false;
3583
+ fused_node_count = 0;
3584
+
3585
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3586
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3587
+
3588
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3589
+ continue;
3590
+ }
3591
+
3592
+ ggml_tensor * mm_node = cgraph->nodes[i];
3593
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
3594
+
3595
+ ggml_tensor * bias_tensor = nullptr;
3596
+ if (bias_op == GGML_OP_ADD) {
3597
+ if (bias_node->src[0] == mm_node) {
3598
+ bias_tensor = bias_node->src[1];
3599
+ } else if (bias_node->src[1] == mm_node) {
3600
+ bias_tensor = bias_node->src[0];
3601
+ } else {
3602
+ continue;
3603
+ }
3604
+ } else {
3605
+ if (bias_node->src[0] != mm_node) {
3606
+ continue;
3607
+ }
3608
+ bias_tensor = bias_node->src[1];
3609
+ }
3610
+
3611
+ const ggml_tensor * src0 = mm_node->src[0];
3612
+ const ggml_tensor * src1 = mm_node->src[1];
3613
+ const ggml_tensor * ids = mm_node->src[2];
3614
+
3615
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3616
+ continue;
3617
+ }
3618
+
3619
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3620
+ continue;
3621
+ }
3622
+
3623
+ ggml_cuda_mm_fusion_args_host fusion_data{};
3624
+ fusion_data.x_bias = bias_tensor;
3625
+
3626
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3627
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3628
+ fused_mul_mat_vec = true;
3629
+ fused_node_count = 2;
3630
+ break;
3631
+ }
3632
+
3633
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3634
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3635
+ fused_mul_mat_vec = true;
3636
+ fused_node_count = 2;
3637
+ break;
3638
+ }
3639
+ }
3640
+
3641
+ if (fused_mul_mat_vec) {
3642
+ i += fused_node_count - 1;
3643
+ continue;
3644
+ }
3645
+
3646
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3647
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
3648
+ i += 2;
3649
+ continue;
3650
+ }
3651
+
3652
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
3653
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
3654
+ i++;
3655
+ continue;
3656
+ }
3657
+
3658
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
3659
+ i += 2;
3660
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
3661
+ continue;
3662
+ }
3663
+ }
3664
+ #ifndef NDEBUG
3665
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
3666
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
3667
+ if (node->src[j] != nullptr) {
3668
+ assert(node->src[j]->buffer);
3669
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
3670
+ ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
3671
+ }
3672
+ }
3673
+ #else
3674
+ GGML_UNUSED(integrated);
3675
+ #endif // NDEBUG
3676
+
3677
+ bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
3678
+ if (!ok) {
3679
+ GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
3680
+ }
3681
+ GGML_ASSERT(ok);
3682
+
3683
+ if (!is_concurrent_event_active) {
3684
+ try_launch_concurrent_event(node);
3685
+ }
3686
+ }
3687
+ }
3688
+
3689
+ #ifdef USE_CUDA_GRAPH
3690
+ if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
3691
+ if (cuda_ctx->cuda_graph->graph != nullptr) {
3692
+ CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
3693
+ cuda_ctx->cuda_graph->graph = nullptr;
3694
+ }
3695
+
3696
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
3697
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
3698
+
2878
3699
  std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2879
- ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
3700
+ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
3701
+ ggml_cuda_lock_cv.notify_all();
3702
+ }
3703
+ } else {
3704
+ graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
2880
3705
  }
3706
+ }
2881
3707
 
2882
- CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3708
+ if (use_cuda_graph) {
3709
+ if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
3710
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
3711
+ }
3712
+ if (cuda_graph_update_required) { // Update graph executable
3713
+ ggml_cuda_graph_update_executable(cuda_ctx);
3714
+ }
3715
+ // Launch graph
3716
+ CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
3717
+ #else
3718
+ graph_evaluated_or_captured = true;
3719
+ #endif // USE_CUDA_GRAPH
3720
+ }
3721
+ }
3722
+
3723
+ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
3724
+
3725
+ #ifdef USE_CUDA_GRAPH
3726
+
3727
+ if (cuda_ctx->cuda_graph == nullptr) {
3728
+ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
2883
3729
  }
2884
3730
 
2885
- if (!use_cuda_graph) {
2886
- cuda_ctx->cuda_graph->use_cpy_indirection = false;
3731
+ if (cuda_ctx->cuda_graph->graph == nullptr) {
3732
+ if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3733
+ cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3734
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
3735
+ }
2887
3736
  }
2888
3737
 
3738
+ return cuda_ctx->cuda_graph->is_enabled();
2889
3739
  #else
2890
- bool use_cuda_graph = false;
3740
+ GGML_UNUSED(cuda_ctx);
3741
+ return false;
3742
+ #endif // USE_CUDA_GRAPH
3743
+ }
3744
+
3745
+ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3746
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3747
+
3748
+ ggml_cuda_set_device(cuda_ctx->device);
3749
+
3750
+ bool use_cuda_graph = false;
2891
3751
  bool cuda_graph_update_required = false;
3752
+
3753
+ #ifdef USE_CUDA_GRAPH
3754
+ use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3755
+
3756
+ if (cuda_ctx->cuda_graph->is_enabled()) {
3757
+ cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3758
+ use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
3759
+
3760
+ cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
3761
+ }
2892
3762
  #endif // USE_CUDA_GRAPH
2893
3763
 
2894
- bool graph_evaluated_or_captured = false;
3764
+ if (use_cuda_graph && cuda_graph_update_required) {
3765
+ // Start CUDA graph capture
3766
+ {
3767
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
3768
+ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
3769
+ }
2895
3770
 
2896
- evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
3771
+ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3772
+ }
3773
+
3774
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
2897
3775
 
2898
3776
  return GGML_STATUS_SUCCESS;
2899
3777
  }
@@ -2923,6 +3801,243 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
2923
3801
  }
2924
3802
  }
2925
3803
 
3804
+ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
3805
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3806
+
3807
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3808
+
3809
+ static bool enable_graph_optimization = [] {
3810
+ const char * env = getenv("GGML_CUDA_GRAPH_OPT");
3811
+ return env != nullptr && atoi(env) == 1;
3812
+ }();
3813
+
3814
+ if (!enable_graph_optimization) {
3815
+ return;
3816
+ }
3817
+
3818
+ ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
3819
+ stream_context.reset();
3820
+
3821
+ if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
3822
+ return;
3823
+ }
3824
+
3825
+ // number of out-degrees for a particular node
3826
+ std::unordered_map<const ggml_tensor *, int> fan_out;
3827
+ // reverse mapping of node to index in the cgraph
3828
+ std::unordered_map<const ggml_tensor *, int> node_indices;
3829
+
3830
+ const auto & is_noop = [](const ggml_tensor * node) -> bool {
3831
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
3832
+ node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
3833
+ };
3834
+
3835
+ const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
3836
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
3837
+ if (dst->src[s] == src) {
3838
+ return true;
3839
+ }
3840
+ }
3841
+ // implicit dependency if they view the same tensor
3842
+ const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
3843
+ const ggml_tensor * src2 = src->view_src ? src->view_src : src;
3844
+ if (dst2 == src2) {
3845
+ return true;
3846
+ }
3847
+ return false;
3848
+ };
3849
+
3850
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
3851
+ const ggml_tensor * node = cgraph->nodes[node_idx];
3852
+ node_indices[node] = node_idx;
3853
+
3854
+ if (is_noop(node)) {
3855
+ continue;
3856
+ }
3857
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3858
+ const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
3859
+ //TODO: check why nrows > 1 fails
3860
+ if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
3861
+ fan_out[src] += 1;
3862
+ }
3863
+ }
3864
+ }
3865
+
3866
+ // Target Q, K, V for concurrency
3867
+ // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
3868
+ // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
3869
+ // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
3870
+ // 3. account for all branches from the fork to the join
3871
+ // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
3872
+ // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
3873
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
3874
+
3875
+ const int min_fan_out = 3;
3876
+ const int max_fan_out = 3;
3877
+
3878
+ // store {fork_idx, join_idx}
3879
+ std::vector<std::pair<int, int>> concurrent_node_ranges;
3880
+
3881
+ for (const auto & [root_node, count] : fan_out) {
3882
+ if (count >= min_fan_out && count <= max_fan_out) {
3883
+ const int root_node_idx = node_indices[root_node];
3884
+
3885
+ // only optimize for attn_norm
3886
+ // TODO: make this more generic
3887
+ if (!strstr(root_node->name, "attn_norm")) {
3888
+ continue;
3889
+ }
3890
+
3891
+ bool is_part_of_event = false;
3892
+ for (const auto & [start, end] : concurrent_node_ranges) {
3893
+ if (root_node_idx >= start && root_node_idx <= end) {
3894
+ is_part_of_event = true;
3895
+ }
3896
+ }
3897
+
3898
+ if (is_part_of_event) {
3899
+ continue;
3900
+ }
3901
+
3902
+ std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
3903
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
3904
+ const ggml_tensor * node = cgraph->nodes[i];
3905
+ if (!is_noop(node) && depends_on(node, root_node)) {
3906
+ nodes_per_branch.push_back({ node });
3907
+ }
3908
+ }
3909
+
3910
+ GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
3911
+
3912
+ //find the join point
3913
+ const ggml_tensor * join_node = nullptr;
3914
+
3915
+ const auto & belongs_to_branch = [&](const ggml_tensor * node,
3916
+ const std::vector<const ggml_tensor *> & branch) -> bool {
3917
+ for (const ggml_tensor * n : branch) {
3918
+ if (depends_on(node, n)) {
3919
+ return true;
3920
+ }
3921
+ }
3922
+ return false;
3923
+ };
3924
+
3925
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
3926
+ const ggml_tensor * curr_node = cgraph->nodes[i];
3927
+
3928
+ int num_joins = 0;
3929
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3930
+ if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
3931
+ num_joins++;
3932
+ }
3933
+ }
3934
+
3935
+ if (num_joins >= 2) {
3936
+ join_node = curr_node;
3937
+ break;
3938
+ }
3939
+
3940
+ bool found_branch = false;
3941
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3942
+ std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
3943
+ if (belongs_to_branch(curr_node, branch_vec)) {
3944
+ //continue accumulating
3945
+ if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
3946
+ branch_vec.push_back(curr_node);
3947
+ }
3948
+ found_branch = true;
3949
+ }
3950
+ }
3951
+
3952
+ if (!found_branch && is_noop(curr_node)) {
3953
+ // we can put it in any branch because it will be ignored
3954
+ nodes_per_branch[0].push_back({ curr_node });
3955
+ }
3956
+ }
3957
+
3958
+ if (join_node) {
3959
+ //Create ggml_cuda_concurrent_event
3960
+ ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
3961
+ concurrent_event.join_node = join_node;
3962
+
3963
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
3964
+ for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
3965
+ concurrent_event.stream_mapping[n] = branch_idx + 1;
3966
+ }
3967
+ }
3968
+
3969
+ int fork_node_idx = node_indices[root_node];
3970
+ int join_node_idx = node_indices[join_node];
3971
+
3972
+ int current_branch_idx = 0;
3973
+ int current_node_idx = fork_node_idx + 1;
3974
+ const int n_branches = nodes_per_branch.size();
3975
+
3976
+ int total_branch_nodes = 0;
3977
+ for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
3978
+ total_branch_nodes += branch_nodes.size();
3979
+ }
3980
+
3981
+ // there are other nodes in the middle which are unaccounted for
3982
+ // usually (cpy) nodes, then ignore this fork
3983
+ if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
3984
+ GGML_LOG_DEBUG(
3985
+ "Skipping %s because the number of nodes in the middle is not equal to the total number of "
3986
+ "branch nodes %d != %d\n",
3987
+ root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
3988
+ continue;
3989
+ }
3990
+
3991
+ // Save the original order of nodes in this region before interleaving
3992
+ // This is used later to restore grouping for fusion within streams
3993
+ concurrent_event.original_order.reserve(total_branch_nodes);
3994
+ for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
3995
+ concurrent_event.original_order.push_back(cgraph->nodes[i]);
3996
+ }
3997
+
3998
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
3999
+ GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
4000
+ concurrent_events.emplace(root_node, std::move(concurrent_event));
4001
+ GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
4002
+ concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
4003
+
4004
+ // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
4005
+ // example transformation:
4006
+ // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
4007
+ // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
4008
+ while (current_node_idx < join_node_idx) {
4009
+ std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
4010
+
4011
+ bool has_node = false;
4012
+ for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
4013
+ has_node |= branch_node.size() > 0;
4014
+ }
4015
+
4016
+ GGML_ASSERT(has_node);
4017
+
4018
+ if (branch_nodes.empty()) {
4019
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4020
+ continue;
4021
+ }
4022
+
4023
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4024
+ current_node_idx++;
4025
+ branch_nodes.erase(branch_nodes.begin());
4026
+
4027
+ // append all empty nodes
4028
+ while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
4029
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4030
+ current_node_idx++;
4031
+ branch_nodes.erase(branch_nodes.begin());
4032
+ }
4033
+
4034
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
4035
+ }
4036
+ }
4037
+ }
4038
+ }
4039
+ }
4040
+
2926
4041
  static const ggml_backend_i ggml_backend_cuda_interface = {
2927
4042
  /* .get_name = */ ggml_backend_cuda_get_name,
2928
4043
  /* .free = */ ggml_backend_cuda_free,
@@ -2937,6 +4052,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
2937
4052
  /* .graph_compute = */ ggml_backend_cuda_graph_compute,
2938
4053
  /* .event_record = */ ggml_backend_cuda_event_record,
2939
4054
  /* .event_wait = */ ggml_backend_cuda_event_wait,
4055
+ /* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
2940
4056
  };
2941
4057
 
2942
4058
  static ggml_guid_t ggml_backend_cuda_guid() {
@@ -2969,7 +4085,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
2969
4085
  return false;
2970
4086
  }
2971
4087
 
2972
- #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
4088
+ #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
2973
4089
  cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
2974
4090
  if (err != cudaSuccess) {
2975
4091
  // clear the error
@@ -3006,6 +4122,8 @@ struct ggml_backend_cuda_device_context {
3006
4122
  int device;
3007
4123
  std::string name;
3008
4124
  std::string description;
4125
+ std::string pci_bus_id;
4126
+ int op_offload_min_batch_size;
3009
4127
  };
3010
4128
 
3011
4129
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3018,10 +4136,110 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
3018
4136
  return ctx->description.c_str();
3019
4137
  }
3020
4138
 
4139
+ #if defined(__linux__)
4140
+ // Helper function to get available memory from /proc/meminfo for UMA systems
4141
+ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
4142
+ FILE * meminfo_file = nullptr;
4143
+ // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough
4144
+ const size_t BUFFER_SIZE = 2048;
4145
+ auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
4146
+ size_t bytes_read = 0;
4147
+ long huge_tlb_total_pages = -1;
4148
+ long huge_tlb_free_pages = -1;
4149
+ long huge_tlb_page_size = -1;
4150
+
4151
+ if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
4152
+ return false;
4153
+ }
4154
+
4155
+ meminfo_file = fopen("/proc/meminfo", "r");
4156
+ if (meminfo_file == nullptr) {
4157
+ GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
4158
+ return false;
4159
+ }
4160
+
4161
+ // Read file into buffer
4162
+ bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
4163
+ fclose(meminfo_file);
4164
+
4165
+ if (bytes_read == 0) {
4166
+ GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
4167
+ return false;
4168
+ }
4169
+ file_buffer[bytes_read] = '\0';
4170
+
4171
+ *available_memory_kb = -1;
4172
+ *free_swap_kb = -1;
4173
+
4174
+ // Parse the file buffer line by line
4175
+ char * line = file_buffer.get();
4176
+ char * line_next;
4177
+ while (line < file_buffer.get() + bytes_read) {
4178
+ // Find the end of the current line
4179
+ line_next = strchr(line, '\n');
4180
+ if (line_next != nullptr) {
4181
+ *line_next = '\0';
4182
+ line_next++;
4183
+ } else {
4184
+ line_next = file_buffer.get() + bytes_read;
4185
+ }
4186
+
4187
+ long value;
4188
+ if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
4189
+ *available_memory_kb = value;
4190
+ } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
4191
+ *free_swap_kb = value;
4192
+ } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
4193
+ huge_tlb_total_pages = value;
4194
+ } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
4195
+ huge_tlb_free_pages = value;
4196
+ } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
4197
+ huge_tlb_page_size = value;
4198
+ }
4199
+
4200
+ line = line_next;
4201
+ }
4202
+
4203
+ if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
4204
+ *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
4205
+
4206
+ // Hugetlbfs pages are not swappable.
4207
+ *free_swap_kb = 0;
4208
+ }
4209
+
4210
+ GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
4211
+ return true;
4212
+ }
4213
+ #endif // defined(__linux__)
4214
+
3021
4215
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
3022
4216
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
3023
4217
  ggml_cuda_set_device(ctx->device);
3024
4218
  CUDA_CHECK(cudaMemGetInfo(free, total));
4219
+
4220
+ // ref: https://github.com/ggml-org/llama.cpp/pull/17368
4221
+ #if defined(__linux__)
4222
+ // Check if this is a UMA (Unified Memory Architecture) system
4223
+ cudaDeviceProp prop;
4224
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
4225
+
4226
+ // Check if UMA is explicitly enabled via environment variable
4227
+ bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
4228
+ bool is_uma = prop.integrated > 0 || uma_env;
4229
+
4230
+ if (is_uma) {
4231
+ // For UMA systems (like DGX Spark), use system memory info
4232
+ long available_memory_kb = 0;
4233
+ long free_swap_kb = 0;
4234
+
4235
+ if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
4236
+ *free = (size_t)available_memory_kb * 1024;
4237
+ } else {
4238
+ GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
4239
+ }
4240
+ }
4241
+ #endif // defined(__linux__)
4242
+
3025
4243
  }
3026
4244
 
3027
4245
  static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -3030,9 +4248,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
3030
4248
  }
3031
4249
 
3032
4250
  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
4251
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
4252
+
3033
4253
  props->name = ggml_backend_cuda_device_get_name(dev);
3034
4254
  props->description = ggml_backend_cuda_device_get_description(dev);
3035
4255
  props->type = ggml_backend_cuda_device_get_type(dev);
4256
+ props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
3036
4257
  ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
3037
4258
 
3038
4259
  bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
@@ -3106,6 +4327,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3106
4327
  case GGML_UNARY_OP_GELU_QUICK:
3107
4328
  case GGML_UNARY_OP_TANH:
3108
4329
  case GGML_UNARY_OP_EXP:
4330
+ case GGML_UNARY_OP_EXPM1:
4331
+ case GGML_UNARY_OP_SOFTPLUS:
4332
+ case GGML_UNARY_OP_ELU:
4333
+ case GGML_UNARY_OP_XIELU:
4334
+ case GGML_UNARY_OP_FLOOR:
4335
+ case GGML_UNARY_OP_CEIL:
4336
+ case GGML_UNARY_OP_ROUND:
4337
+ case GGML_UNARY_OP_TRUNC:
3109
4338
  return ggml_is_contiguous(op->src[0]);
3110
4339
  default:
3111
4340
  return false;
@@ -3116,6 +4345,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3116
4345
  case GGML_GLU_OP_REGLU:
3117
4346
  case GGML_GLU_OP_GEGLU:
3118
4347
  case GGML_GLU_OP_SWIGLU:
4348
+ case GGML_GLU_OP_SWIGLU_OAI:
4349
+ case GGML_GLU_OP_GEGLU_ERF:
4350
+ case GGML_GLU_OP_GEGLU_QUICK:
3119
4351
  return ggml_is_contiguous_1(op->src[0]);
3120
4352
  default:
3121
4353
  return false;
@@ -3164,6 +4396,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3164
4396
  case GGML_TYPE_Q5_0:
3165
4397
  case GGML_TYPE_Q5_1:
3166
4398
  case GGML_TYPE_Q8_0:
4399
+ case GGML_TYPE_MXFP4:
3167
4400
  case GGML_TYPE_Q2_K:
3168
4401
  case GGML_TYPE_Q3_K:
3169
4402
  case GGML_TYPE_Q4_K:
@@ -3192,6 +4425,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3192
4425
  switch (op->src[0]->type) {
3193
4426
  case GGML_TYPE_F16:
3194
4427
  case GGML_TYPE_F32:
4428
+ case GGML_TYPE_BF16:
4429
+ case GGML_TYPE_I32:
3195
4430
  case GGML_TYPE_Q4_0:
3196
4431
  case GGML_TYPE_Q4_1:
3197
4432
  case GGML_TYPE_Q5_0:
@@ -3206,17 +4441,28 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3206
4441
  {
3207
4442
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
3208
4443
  } break;
4444
+ case GGML_OP_SET_ROWS:
4445
+ {
4446
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
4447
+ op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
4448
+ op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
4449
+ op->src[0]->type == GGML_TYPE_F32 &&
4450
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
4451
+ } break;
4452
+ case GGML_OP_SET:
4453
+ {
4454
+ const ggml_type t = op->type;
4455
+ return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
4456
+ t == op->src[0]->type &&
4457
+ t == op->src[1]->type;
4458
+ } break;
3209
4459
  case GGML_OP_CPY:
3210
4460
  {
3211
4461
  ggml_type src0_type = op->src[0]->type;
3212
4462
  ggml_type src1_type = op->src[1]->type;
3213
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
3214
- return true;
3215
- }
3216
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
3217
- return true;
3218
- }
3219
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
4463
+ if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
4464
+ (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
4465
+ ) {
3220
4466
  return true;
3221
4467
  }
3222
4468
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3252,10 +4498,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3252
4498
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
3253
4499
  return true;
3254
4500
  }
3255
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
4501
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
4502
+ return true;
4503
+ }
4504
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
3256
4505
  return true;
3257
4506
  }
3258
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
4507
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
3259
4508
  return true;
3260
4509
  }
3261
4510
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
@@ -3310,6 +4559,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3310
4559
  case GGML_OP_PERMUTE:
3311
4560
  case GGML_OP_TRANSPOSE:
3312
4561
  case GGML_OP_ADD:
4562
+ case GGML_OP_ADD_ID:
3313
4563
  case GGML_OP_ADD1:
3314
4564
  case GGML_OP_SUB:
3315
4565
  case GGML_OP_MUL:
@@ -3321,12 +4571,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3321
4571
  case GGML_OP_COS:
3322
4572
  case GGML_OP_CLAMP:
3323
4573
  case GGML_OP_LOG:
3324
- case GGML_OP_SSM_SCAN:
3325
- case GGML_OP_SSM_CONV:
3326
4574
  return true;
4575
+ case GGML_OP_SSM_SCAN: {
4576
+ if (op->src[3]->ne[0] == 1) {
4577
+ // Mamba2
4578
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
4579
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
4580
+ } else {
4581
+ // Mamba
4582
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
4583
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
4584
+ }
4585
+ }
4586
+ case GGML_OP_SSM_CONV: {
4587
+ // assumes d_inner % threads == 0
4588
+ return op->src[0]->ne[1] % 128 == 0;
4589
+ }
3327
4590
  case GGML_OP_CONT:
3328
- return op->src[0]->type != GGML_TYPE_BF16;
4591
+ return true;
3329
4592
  case GGML_OP_DIAG_MASK_INF:
4593
+ return true;
3330
4594
  case GGML_OP_SOFT_MAX:
3331
4595
  return true;
3332
4596
  case GGML_OP_SOFT_MAX_BACK: {
@@ -3334,25 +4598,39 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3334
4598
  memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
3335
4599
  return max_bias == 0.0f;
3336
4600
  }
4601
+ case GGML_OP_ROLL:
4602
+ if(op->src[0]->type == GGML_TYPE_F32) {
4603
+ return true;
4604
+ }
4605
+ return false;
3337
4606
  case GGML_OP_ROPE:
3338
4607
  case GGML_OP_ROPE_BACK: {
3339
4608
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
3340
4609
  }
3341
4610
  case GGML_OP_IM2COL:
4611
+ case GGML_OP_IM2COL_3D:
4612
+ case GGML_OP_CONV_2D:
3342
4613
  case GGML_OP_CONV_2D_DW:
3343
4614
  case GGML_OP_CONV_TRANSPOSE_2D:
3344
4615
  case GGML_OP_POOL_2D:
4616
+ case GGML_OP_ACC:
4617
+ return true;
3345
4618
  case GGML_OP_SUM:
3346
- case GGML_OP_SUM_ROWS:
3347
- case GGML_OP_MEAN:
4619
+ return ggml_is_contiguous_rows(op->src[0]);
4620
+ case GGML_OP_TOP_K:
3348
4621
  case GGML_OP_ARGSORT:
3349
- case GGML_OP_ACC:
4622
+ #ifndef GGML_CUDA_USE_CUB
4623
+ return op->src[0]->ne[0] <= 1024;
4624
+ #else
3350
4625
  return true;
4626
+ #endif
4627
+ case GGML_OP_SUM_ROWS:
4628
+ case GGML_OP_MEAN:
3351
4629
  case GGML_OP_GROUP_NORM:
4630
+ case GGML_OP_PAD:
3352
4631
  return ggml_is_contiguous(op->src[0]);
3353
4632
  case GGML_OP_UPSCALE:
3354
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
3355
- case GGML_OP_PAD:
4633
+ case GGML_OP_PAD_REFLECT_1D:
3356
4634
  case GGML_OP_ARANGE:
3357
4635
  case GGML_OP_TIMESTEP_EMBEDDING:
3358
4636
  case GGML_OP_LEAKY_RELU:
@@ -3360,43 +4638,19 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3360
4638
  case GGML_OP_GATED_LINEAR_ATTN:
3361
4639
  case GGML_OP_RWKV_WKV7:
3362
4640
  return true;
3363
- case GGML_OP_FLASH_ATTN_EXT: {
3364
- #ifndef FLASH_ATTN_AVAILABLE
3365
- return false;
3366
- #endif // FLASH_ATTN_AVAILABLE
3367
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
3368
- const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3369
- if (!new_mma_available(cc)) {
3370
- return false;
3371
- }
3372
- const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
3373
- return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
3374
- }
3375
- if (op->src[0]->ne[0] == 192) {
3376
- return false;
3377
- }
3378
- if (op->src[0]->ne[3] != 1) {
3379
- return false;
3380
- }
3381
- if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
3382
- return false;
3383
- }
3384
- if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
3385
- return true;
3386
- }
3387
- if (op->src[0]->ne[0] == 128) {
3388
- return true;
3389
- }
3390
- if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
3391
- return true;
3392
- }
3393
- return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
3394
- op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
3395
- }
4641
+ case GGML_OP_FLASH_ATTN_EXT:
4642
+ return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
3396
4643
  case GGML_OP_CROSS_ENTROPY_LOSS:
3397
4644
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
3398
4645
  case GGML_OP_OPT_STEP_ADAMW:
4646
+ case GGML_OP_OPT_STEP_SGD:
4647
+ case GGML_OP_FILL:
4648
+ case GGML_OP_CUMSUM:
4649
+ case GGML_OP_TRI:
4650
+ case GGML_OP_DIAG:
4651
+ case GGML_OP_SOLVE_TRI:
3399
4652
  return true;
4653
+
3400
4654
  default:
3401
4655
  return false;
3402
4656
  }
@@ -3424,11 +4678,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
3424
4678
  }
3425
4679
 
3426
4680
  static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
3427
- const int min_batch_size = 32;
3428
-
3429
- return get_op_batch_size(op) >= min_batch_size;
4681
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
3430
4682
 
3431
- GGML_UNUSED(dev);
4683
+ return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
3432
4684
  }
3433
4685
 
3434
4686
  static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@@ -3527,10 +4779,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
3527
4779
  features.push_back({ "NO_PEER_COPY", "1" });
3528
4780
  #endif
3529
4781
 
3530
- #ifdef GGML_CUDA_F16
3531
- features.push_back({ "F16", "1" });
3532
- #endif
3533
-
3534
4782
  #ifdef GGML_CUDA_USE_GRAPHS
3535
4783
  features.push_back({ "USE_GRAPHS", "1" });
3536
4784
  #endif
@@ -3543,6 +4791,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
3543
4791
  features.push_back({ "FA_ALL_QUANTS", "1" });
3544
4792
  #endif
3545
4793
 
4794
+ {
4795
+ const auto & info = ggml_cuda_info();
4796
+ for (int id = 0; id < info.device_count; ++id) {
4797
+ if (blackwell_mma_available(info.devices[id].cc)) {
4798
+ features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
4799
+ break;
4800
+ }
4801
+ }
4802
+ }
4803
+
3546
4804
  #undef _STRINGIFY
3547
4805
  #undef STRINGIFY
3548
4806
 
@@ -3590,17 +4848,22 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
3590
4848
  std::lock_guard<std::mutex> lock(mutex);
3591
4849
  if (!initialized) {
3592
4850
  ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
4851
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
3593
4852
 
3594
4853
  for (int i = 0; i < ggml_cuda_info().device_count; i++) {
3595
4854
  ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
3596
4855
  dev_ctx->device = i;
3597
4856
  dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
3598
4857
 
3599
- ggml_cuda_set_device(i);
3600
4858
  cudaDeviceProp prop;
3601
4859
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
3602
4860
  dev_ctx->description = prop.name;
3603
4861
 
4862
+ char pci_bus_id[16] = {};
4863
+ snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
4864
+ dev_ctx->pci_bus_id = pci_bus_id;
4865
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
4866
+
3604
4867
  ggml_backend_dev_t dev = new ggml_backend_device {
3605
4868
  /* .iface = */ ggml_backend_cuda_device_interface,
3606
4869
  /* .reg = */ &reg,
@@ -3635,10 +4898,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
3635
4898
  }
3636
4899
 
3637
4900
  ggml_backend_t cuda_backend = new ggml_backend {
3638
- /* .guid = */ ggml_backend_cuda_guid(),
3639
- /* .interface = */ ggml_backend_cuda_interface,
3640
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
3641
- /* .context = */ ctx,
4901
+ /* .guid = */ ggml_backend_cuda_guid(),
4902
+ /* .iface = */ ggml_backend_cuda_interface,
4903
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
4904
+ /* .context = */ ctx,
3642
4905
  };
3643
4906
 
3644
4907
  return cuda_backend;