whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -2,6 +2,7 @@
2
2
  #include "ggml-impl.h"
3
3
  #include "ggml-backend-impl.h"
4
4
 
5
+ #include "ggml-cuda/allreduce.cuh"
5
6
  #include "ggml-cuda/common.cuh"
6
7
  #include "ggml-cuda/acc.cuh"
7
8
  #include "ggml-cuda/add-id.cuh"
@@ -23,6 +24,7 @@
23
24
  #include "ggml-cuda/diagmask.cuh"
24
25
  #include "ggml-cuda/diag.cuh"
25
26
  #include "ggml-cuda/fattn.cuh"
27
+ #include "ggml-cuda/fwht.cuh"
26
28
  #include "ggml-cuda/getrows.cuh"
27
29
  #include "ggml-cuda/im2col.cuh"
28
30
  #include "ggml-cuda/mmf.cuh"
@@ -39,6 +41,7 @@
39
41
  #include "ggml-cuda/rope.cuh"
40
42
  #include "ggml-cuda/roll.cuh"
41
43
  #include "ggml-cuda/scale.cuh"
44
+ #include "ggml-cuda/snake.cuh"
42
45
  #include "ggml-cuda/softcap.cuh"
43
46
  #include "ggml-cuda/softmax.cuh"
44
47
  #include "ggml-cuda/ssm-conv.cuh"
@@ -82,10 +85,12 @@
82
85
  #include <cstdlib>
83
86
  #include <string>
84
87
  #include <vector>
85
- #include <unordered_set>
86
88
 
87
89
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
88
90
 
91
+ #define GGML_LOG_WARN_ONCE(str) \
92
+ { static std::once_flag warn_flag; std::call_once(warn_flag, []() { GGML_LOG_WARN(str); }); }
93
+
89
94
  [[noreturn]]
90
95
  void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
91
96
  int id = -1; // in case cudaGetDevice fails
@@ -126,7 +131,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
126
131
  if (err == hipSuccess) {
127
132
  // hipMemAdviseSetCoarseGrain is an optional performance hint;
128
133
  // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
129
- cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
134
+ (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
130
135
  (void)hipGetLastError(); // clear any error
131
136
  }
132
137
 
@@ -325,6 +330,22 @@ static ggml_cuda_device_info ggml_cuda_init() {
325
330
  // configure logging to stdout
326
331
  // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
327
332
 
333
+ if (getenv("GGML_CUDA_P2P") != nullptr) {
334
+ for (int id = 0; id < info.device_count; ++id) {
335
+ ggml_cuda_set_device(id);
336
+ for (int id_other = 0; id_other < info.device_count; ++id_other) {
337
+ if (id == id_other) {
338
+ continue;
339
+ }
340
+ int can_access_peer;
341
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
342
+ if (can_access_peer) {
343
+ CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
344
+ }
345
+ }
346
+ }
347
+ }
348
+
328
349
  return info;
329
350
  }
330
351
 
@@ -353,15 +374,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
353
374
  }
354
375
 
355
376
  ~ggml_cuda_pool_leg() {
377
+ clear_pool();
378
+ GGML_ASSERT(pool_size == 0);
379
+ }
380
+
381
+ void clear_pool() {
356
382
  ggml_cuda_set_device(device);
357
383
  for (int i = 0; i < MAX_BUFFERS; ++i) {
358
384
  ggml_cuda_buffer & b = buffer_pool[i];
359
385
  if (b.ptr != nullptr) {
360
386
  CUDA_CHECK(cudaFree(b.ptr));
361
387
  pool_size -= b.size;
388
+ b.ptr = nullptr;
389
+ b.size = 0;
362
390
  }
363
391
  }
364
- GGML_ASSERT(pool_size == 0);
365
392
  }
366
393
 
367
394
  void * alloc(size_t size, size_t * actual_size) override {
@@ -406,7 +433,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
406
433
  size_t look_ahead_size = (size_t) (1.05 * size);
407
434
  look_ahead_size = 256 * ((look_ahead_size + 255)/256);
408
435
  ggml_cuda_set_device(device);
409
- CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
436
+ cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
437
+ if (err == cudaErrorMemoryAllocation) {
438
+ (void)cudaGetLastError();
439
+ const size_t cached_bytes = pool_size;
440
+ GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n",
441
+ device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0);
442
+ CUDA_CHECK(cudaDeviceSynchronize());
443
+ clear_pool();
444
+ err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
445
+ if (err == cudaSuccess) {
446
+ GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device);
447
+ }
448
+ }
449
+ CUDA_CHECK(err);
410
450
  *actual_size = look_ahead_size;
411
451
  pool_size += look_ahead_size;
412
452
  #ifdef DEBUG_CUDA_MALLOC
@@ -582,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
582
622
 
583
623
  // cuda buffer
584
624
 
625
// Per-device context of the CUDA backend device.
// Scalar members carry default initializers so a default-constructed instance never
// exposes indeterminate values before its creator fills them in.
struct ggml_backend_cuda_device_context {
    int         device = 0;
    std::string name;
    std::string description;
    std::string pci_bus_id;
    int         op_offload_min_batch_size = 0;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    // Guards active_count. Buffer/host-buffer allocation increments it; freeing buffers,
    // host buffers and backends decrements it (backend-side increment presumably happens
    // at backend init -- not visible here, confirm).
    std::mutex device_mutex;
    int        active_count = 0;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
};
636
+
585
637
  struct ggml_backend_cuda_buffer_context {
586
638
  int device;
587
639
  void * dev_ptr = nullptr;
@@ -599,6 +651,13 @@ struct ggml_backend_cuda_buffer_context {
599
651
 
600
652
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
601
653
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
654
+
655
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
656
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
657
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
658
+ dev_ctx->active_count--;
659
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
660
+
602
661
  delete ctx;
603
662
  }
604
663
 
@@ -633,26 +692,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
633
692
  }
634
693
 
635
694
  static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
636
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
695
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
637
696
 
638
697
  ggml_cuda_set_device(ctx->device);
639
- CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
698
+ CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread));
640
699
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
641
700
  }
642
701
 
643
702
  static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
644
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
703
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
645
704
 
646
705
  ggml_cuda_set_device(ctx->device);
647
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
706
+ CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
648
707
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
649
708
  }
650
709
 
651
710
  static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
711
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
712
+
713
+ ggml_cuda_set_device(ctx->device);
714
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
715
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
716
+ }
717
+
718
+ static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data,
719
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
720
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
721
+
722
+ ggml_cuda_set_device(ctx->device);
723
+ CUDA_CHECK(cudaMemcpy2DAsync(
724
+ (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread));
725
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
726
+ }
727
+
728
+ static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data,
729
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
652
730
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
653
731
 
654
732
  ggml_cuda_set_device(ctx->device);
655
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
733
+ CUDA_CHECK(cudaMemcpy2DAsync(
734
+ data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread));
656
735
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
657
736
  }
658
737
 
@@ -692,6 +771,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
692
771
  /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
693
772
  /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
694
773
  /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
774
+ /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d,
775
+ /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d,
695
776
  /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
696
777
  /* .clear = */ ggml_backend_cuda_buffer_clear,
697
778
  /* .reset = */ NULL,
@@ -729,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
729
810
 
730
811
  ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
731
812
 
813
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
814
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
815
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
816
+ dev_ctx->active_count++;
817
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
818
+
732
819
  return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
733
820
  }
734
821
 
@@ -739,7 +826,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty
739
826
  }
740
827
 
741
828
  static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
742
- size_t size = ggml_nbytes(tensor);
829
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context;
830
+
831
+ size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT
832
+ ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor)
833
+ : ggml_nbytes(tensor);
743
834
  int64_t ne0 = tensor->ne[0];
744
835
 
745
836
  if (ggml_is_quantized(tensor->type)) {
@@ -750,8 +841,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
750
841
  }
751
842
 
752
843
  return size;
753
-
754
- GGML_UNUSED(buft);
755
844
  }
756
845
 
757
846
  static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
@@ -1004,6 +1093,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
1004
1093
  /* .memset_tensor = */ NULL,
1005
1094
  /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
1006
1095
  /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
1096
+ /* .set_tensor_2d = */ NULL,
1097
+ /* .get_tensor_2d = */ NULL,
1007
1098
  /* .cpy_tensor = */ NULL,
1008
1099
  /* .clear = */ ggml_backend_cuda_split_buffer_clear,
1009
1100
  /* .reset = */ NULL,
@@ -1080,6 +1171,295 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
1080
1171
  /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
1081
1172
  };
1082
1173
 
1174
+ // Communication context for multi-GPU AllReduce during tensor parallelism.
1175
+ //
1176
+ // Created once per meta backend instance. Resources for the selected mode
1177
+ // (NCCL communicators or the internal AllReduce pipeline) are initialised
1178
+ // eagerly during comm_init so any init failure surfaces at startup rather
1179
+ // than mid-run.
1180
+ struct ggml_backend_cuda_comm_context {
1181
+ using try_allreduce_fn = bool(*)(ggml_backend_cuda_comm_context *, struct ggml_tensor **);
1182
+
1183
+ std::vector<ggml_backend_t> backends;
1184
+ std::vector<int> dev_ids;
1185
+
1186
+ // Set by the init chain (comm_init_{nccl, internal, none}) to one of
1187
+ // try_allreduce_{nccl, internal, butterfly}. nccl needs `comms`,
1188
+ // internal needs `ar_pipeline`, butterfly needs nothing. Per-call
1189
+ // failures return false; the meta backend's generic implementation then
1190
+ // handles that call.
1191
+ try_allreduce_fn try_allreduce = nullptr;
1192
+
1193
+ ggml_cuda_ar_pipeline * ar_pipeline = nullptr;
1194
+
1195
+ #ifdef GGML_USE_NCCL
1196
+ std::vector<ncclComm_t> comms;
1197
+ #endif // GGML_USE_NCCL
1198
+
1199
+ ~ggml_backend_cuda_comm_context() {
1200
+ #ifdef GGML_USE_NCCL
1201
+ for (ncclComm_t comm : comms) {
1202
+ NCCL_CHECK(ncclCommDestroy(comm));
1203
+ }
1204
+ #endif // GGML_USE_NCCL
1205
+ ggml_cuda_ar_pipeline_free(ar_pipeline);
1206
+ }
1207
+ };
1208
+
1209
#ifdef GGML_USE_NCCL
// AllReduce of FP32 tensors via NCCL. Small tensors are reduced directly as FP32; large
// tensors (bandwidth-bound) are compressed to BF16 for the reduction and converted back
// to FP32 afterwards. Both paths reinterpret the data as FP32, so any other tensor type
// returns false and the meta backend's generic implementation handles that call.
// Tensors without GGML_TENSOR_FLAG_COMPUTE contribute zeros to the sum.
static bool ggml_backend_cuda_comm_allreduce_nccl(
        ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
    const int64_t ne = ggml_nelements(tensors[0]);
    // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
    // This then causes a crash in this function
    if (ne == 0) {
        return true;
    }

    const size_t n_backends = comm_ctx->backends.size();
    // tmp[] below is a fixed-size per-device array.
    GGML_ASSERT(n_backends <= (size_t) GGML_CUDA_MAX_DEVICES);

    for (size_t i = 0; i < n_backends; ++i) {
        GGML_ASSERT(tensors[i] != nullptr);
        GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
        GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
        if (tensors[i]->type != GGML_TYPE_F32) {
            // ncclFloat / to_bf16(F32) would read ne*4 bytes from a smaller buffer.
            GGML_LOG_DEBUG("%s: nccl unsupported: tensor[%zu] type=%d (only F32)\n", __func__, i, (int) tensors[i]->type);
            return false;
        }
    }

    // For small tensors, simply reduce them as FP32.
    // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
    if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
        for (size_t i = 0; i < n_backends; ++i) {
            if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
                ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
                ggml_cuda_set_device(cuda_ctx->device);
                CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream()));
            }
        }
        NCCL_CHECK(ncclGroupStart());
        for (size_t i = 0; i < n_backends; ++i) {
            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
        }
        NCCL_CHECK(ncclGroupEnd());
        return true;
    }

    // For large tensors it's faster to compress them to BF16 for the reduction:
    to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
    to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);

    ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
    for (size_t i = 0; i < n_backends; ++i) {
        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
        tmp[i].pool = &cuda_ctx->pool();
        tmp[i].alloc(ne);

        ggml_cuda_set_device(cuda_ctx->device);
        if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) {
            to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
        } else {
            CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream()));
        }
        CUDA_CHECK(cudaGetLastError());
    }

    NCCL_CHECK(ncclGroupStart());
    for (size_t i = 0; i < n_backends; ++i) {
        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
    }
    NCCL_CHECK(ncclGroupEnd());

    for (size_t i = 0; i < n_backends; ++i) {
        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;

        ggml_cuda_set_device(cuda_ctx->device);
        to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
        CUDA_CHECK(cudaGetLastError());
    }

    return true;
}
#endif // GGML_USE_NCCL
1285
+
1286
+ // Run the internal AR pipeline. Returns false on unsupported / failed input
1287
+ // -- the caller decides whether to abort (env-forced) or fall back silently.
1288
+ static bool ggml_backend_cuda_comm_allreduce_internal(
1289
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1290
+ GGML_ASSERT(comm_ctx->ar_pipeline != nullptr);
1291
+
1292
+ const size_t n_backends = comm_ctx->backends.size();
1293
+ GGML_ASSERT(n_backends == 2);
1294
+ GGML_ASSERT(tensors[0] != nullptr);
1295
+
1296
+ const int64_t ne = ggml_nelements(tensors[0]);
1297
+ const ggml_type type = tensors[0]->type;
1298
+
1299
+ if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) {
1300
+ GGML_LOG_DEBUG("%s: internal unsupported: type=%d\n", __func__, (int) type);
1301
+ return false;
1302
+ }
1303
+
1304
+ if (ne == 0) {
1305
+ return true;
1306
+ }
1307
+
1308
+ for (size_t i = 0; i < n_backends; ++i) {
1309
+ if (tensors[i] == nullptr) {
1310
+ GGML_LOG_ERROR("%s: internal failed: tensor[%zu] is null\n", __func__, i);
1311
+ return false;
1312
+ }
1313
+ if (ggml_nelements(tensors[i]) != ne || tensors[i]->type != type) {
1314
+ GGML_LOG_ERROR("%s: internal failed: tensor[%zu] ne=%" PRId64 " type=%d expected ne=%" PRId64 " type=%d\n",
1315
+ __func__, i, ggml_nelements(tensors[i]), (int) tensors[i]->type, ne, (int) type);
1316
+ return false;
1317
+ }
1318
+ if (!ggml_is_contiguously_allocated(tensors[i])) {
1319
+ GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] is not contiguously allocated: ne=%" PRId64 " nbytes=%zu packed=%zu type=%d\n",
1320
+ __func__, i, ne, ggml_nbytes(tensors[i]),
1321
+ (size_t) ne * ggml_type_size(type) / ggml_blck_size(type), (int) type);
1322
+ return false;
1323
+ }
1324
+ if (((uintptr_t) tensors[i]->data & 0xF) != 0) {
1325
+ GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] data pointer is not 16-byte aligned: %p type=%d ne=%" PRId64 "\n",
1326
+ __func__, i, tensors[i]->data, (int) type, ne);
1327
+ return false;
1328
+ }
1329
+ GGML_ASSERT((ggml_nbytes(tensors[i]) & 0xF) == 0);
1330
+ }
1331
+
1332
+ return ggml_cuda_ar_allreduce(comm_ctx->ar_pipeline, comm_ctx->backends.data(), tensors);
1333
+ }
1334
+
1335
+ // ---------------------------------------------------------------------------
1336
+ // Per-call dispatch -- three variants, one per backend. Each is set as
1337
+ // comm_ctx->try_allreduce by the matching init step. Per-call failure
1338
+ // returns false; the meta backend's generic implementation handles that call.
1339
+ // ---------------------------------------------------------------------------
1340
+
1341
+ #ifdef GGML_USE_NCCL
1342
+ static bool ggml_backend_cuda_comm_try_allreduce_nccl(
1343
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1344
+ return ggml_backend_cuda_comm_allreduce_nccl(comm_ctx, tensors);
1345
+ }
1346
+ #endif // GGML_USE_NCCL
1347
+
1348
+ static bool ggml_backend_cuda_comm_try_allreduce_internal(
1349
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1350
+ return ggml_backend_cuda_comm_allreduce_internal(comm_ctx, tensors);
1351
+ }
1352
+
1353
+ static bool ggml_backend_cuda_comm_try_allreduce_butterfly(
1354
+ ggml_backend_cuda_comm_context *, struct ggml_tensor **) {
1355
+ return false;
1356
+ }
1357
+
1358
+ static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
1359
+ if (comm_ctx_v == nullptr) {
1360
+ return;
1361
+ }
1362
+ delete static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v);
1363
+ }
1364
+
1365
+ // ---------------------------------------------------------------------------
1366
+ // Init -- chained nccl -> internal -> none. Each step tries to bring up its
1367
+ // resource; on failure it warns and recurses into the next step.
1368
+ // ---------------------------------------------------------------------------
1369
+ static void ggml_backend_cuda_comm_init_none(ggml_backend_cuda_comm_context * ret) {
1370
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_butterfly;
1371
+ }
1372
+
1373
+ static void ggml_backend_cuda_comm_init_internal(ggml_backend_cuda_comm_context * ret) {
1374
+ ret->ar_pipeline = ggml_cuda_ar_pipeline_init(ret->dev_ids.data(), ret->dev_ids.size());
1375
+ if (ret->ar_pipeline) {
1376
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_internal;
1377
+ return;
1378
+ }
1379
+
1380
+ // Clear sticky CUDA error from the failed init.
1381
+ (void) cudaGetLastError();
1382
+ GGML_LOG_WARN("internal AllReduce init failed (n_devices != 2?); "
1383
+ "falling back to meta-backend butterfly\n");
1384
+ ggml_backend_cuda_comm_init_none(ret);
1385
+ }
1386
+
1387
+ static void ggml_backend_cuda_comm_init_nccl(ggml_backend_cuda_comm_context * ret) {
1388
+ #ifdef GGML_USE_NCCL
1389
+ const size_t n = ret->dev_ids.size();
1390
+ ret->comms.resize(n);
1391
+ ncclResult_t rc = ncclCommInitAll(ret->comms.data(), (int) n, ret->dev_ids.data());
1392
+ if (rc == ncclSuccess) {
1393
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_nccl;
1394
+ return;
1395
+ }
1396
+
1397
+ ret->comms.clear();
1398
+ GGML_LOG_WARN("NCCL init failed (%s); falling back to internal AllReduce\n",
1399
+ ncclGetErrorString(rc));
1400
+ #else // GGML_USE_NCCL
1401
+ #ifndef GGML_USE_HIP
1402
+ GGML_LOG_WARN("NCCL not compiled in; falling back to internal AllReduce. "
1403
+ "Recompile with -DGGML_CUDA_NCCL=ON for best multi-GPU performance.\n");
1404
+ #endif // !GGML_USE_HIP
1405
+ #endif // GGML_USE_NCCL
1406
+
1407
+ ggml_backend_cuda_comm_init_internal(ret);
1408
+ }
1409
+
1410
+ // Top-level init. Picks one of the three init paths based on
1411
+ // GGML_CUDA_ALLREDUCE (or the platform default) and lets the chain handle
1412
+ // any fallback. Unrecognised env values warn and fall through to the
1413
+ // platform default.
1414
+ static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
1415
+ for (size_t i = 0; i < n_backends; i++) {
1416
+ if (!ggml_backend_is_cuda(backends[i])) {
1417
+ return nullptr;
1418
+ }
1419
+ }
1420
+
1421
+ auto * ret = new ggml_backend_cuda_comm_context;
1422
+ ret->backends.assign(backends, backends + n_backends);
1423
+ ret->dev_ids.reserve(n_backends);
1424
+ for (size_t i = 0; i < n_backends; i++) {
1425
+ ret->dev_ids.push_back(static_cast<ggml_backend_cuda_context *>(backends[i]->context)->device);
1426
+ }
1427
+
1428
+ const char * env = getenv("GGML_CUDA_ALLREDUCE");
1429
+ if (!env) {
1430
+ // Platform default: Linux uses NCCL, otherwise (generally Windows) internal
1431
+ #if defined(__linux__)
1432
+ ggml_backend_cuda_comm_init_nccl(ret);
1433
+ #else
1434
+ ggml_backend_cuda_comm_init_internal(ret);
1435
+ #endif // defined(__linux__)
1436
+ } else {
1437
+ std::string env_str(env);
1438
+ if (env_str == "nccl") {
1439
+ ggml_backend_cuda_comm_init_nccl(ret);
1440
+ } else if (env_str == "internal") {
1441
+ ggml_backend_cuda_comm_init_internal(ret);
1442
+ } else if (env_str == "none") {
1443
+ ggml_backend_cuda_comm_init_none(ret);
1444
+ } else {
1445
+ GGML_LOG_WARN("unknown GGML_CUDA_ALLREDUCE value: %s\n", env);
1446
+ ggml_backend_cuda_comm_init_none(ret);
1447
+ }
1448
+ }
1449
+
1450
+ return ret;
1451
+ }
1452
+
1453
+ // Top-level dispatch -- calls the function pointer chosen by comm_init.
1454
+ // Returns false to let the meta-backend's butterfly run.
1455
+ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
1456
+ if (comm_ctx_v == nullptr) {
1457
+ return false;
1458
+ }
1459
+ auto * comm_ctx = static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v);
1460
+ return comm_ctx->try_allreduce(comm_ctx, tensors);
1461
+ }
1462
+
1083
1463
  ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
1084
1464
  static std::mutex mutex;
1085
1465
  std::lock_guard<std::mutex> lock(mutex);
@@ -1135,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
1135
1515
  }
1136
1516
 
1137
1517
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1518
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1519
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
1520
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
1521
+ dev_ctx->active_count--;
1522
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1523
+
1138
1524
  CUDA_CHECK(cudaFreeHost(buffer->context));
1139
1525
  }
1140
1526
 
@@ -1143,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
1143
1529
  return nullptr;
1144
1530
  }
1145
1531
 
1532
+ ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.
1533
+
1146
1534
  void * ptr = nullptr;
1147
1535
  cudaError_t err = cudaMallocHost((void **) &ptr, size);
1148
1536
  if (err != cudaSuccess) {
@@ -1168,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
1168
1556
  buffer->buft = buft;
1169
1557
  buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
1170
1558
 
1559
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1560
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
1561
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
1562
+ dev_ctx->active_count++;
1563
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1564
+
1171
1565
  return buffer;
1172
1566
  }
1173
1567
 
@@ -1297,7 +1691,12 @@ static void ggml_cuda_op_mul_mat_cublas(
1297
1691
  const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1298
1692
  (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
1299
1693
 
1300
- const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
1694
+ const bool use_fp16 =
1695
+ src0->type != GGML_TYPE_NVFP4 &&
1696
+ (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
1697
+ ggml_is_contiguous(src0) &&
1698
+ row_diff == src0->ne[1] &&
1699
+ dst->op_params[0] == GGML_PREC_DEFAULT;
1301
1700
 
1302
1701
  if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1303
1702
  ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -1421,64 +1820,6 @@ static void ggml_cuda_op_mul_mat_cublas(
1421
1820
  GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
1422
1821
  }
1423
1822
 
1424
// NOTE(review): the lines below are removed in this version. They held a lazy toggle that
// enabled/disabled peer access between the main device and the other devices, depending on
// whether n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE. Its only visible caller (in
// ggml_cuda_compute_forward) is removed too; the new init code enables peer access once instead.
- static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
1425
- static bool peer_access_enabled = false;
1426
-
1427
- const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
1428
-
1429
- if (peer_access_enabled == enable_peer_access) {
1430
- return;
1431
- }
1432
-
1433
- #ifdef NDEBUG
1434
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
1435
- ggml_cuda_set_device(id);
1436
- CUDA_CHECK(cudaDeviceSynchronize());
1437
- }
1438
-
1439
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
1440
- ggml_cuda_set_device(id);
1441
-
1442
- for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
1443
- if (id == id_other) {
1444
- continue;
1445
- }
1446
- if (id != main_device && id_other != main_device) {
1447
- continue;
1448
- }
1449
-
1450
- int can_access_peer;
1451
- CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
1452
- if (can_access_peer) {
1453
- if (enable_peer_access) {
1454
- cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
1455
- if (err != cudaErrorPeerAccessAlreadyEnabled) {
1456
- CUDA_CHECK(err);
1457
- } else {
1458
- // reset the error
1459
- (void)cudaGetLastError();
1460
- }
1461
- } else {
1462
- cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
1463
- if (err != cudaErrorPeerAccessNotEnabled) {
1464
- CUDA_CHECK(err);
1465
- } else {
1466
- // reset the error
1467
- (void)cudaGetLastError();
1468
- }
1469
- }
1470
- }
1471
- }
1472
- }
1473
-
1474
- ggml_cuda_set_device(main_device);
1475
- #endif // NDEBUG
1476
-
1477
- peer_access_enabled = enable_peer_access;
1478
-
1479
- GGML_UNUSED(main_device);
1480
- }
1481
-
1482
1823
  static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
1483
1824
  void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
1484
1825
 
@@ -2270,6 +2611,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2270
2611
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2271
2612
  use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2272
2613
  use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2614
+ use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
2273
2615
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2274
2616
  }
2275
2617
  } else {
@@ -2278,6 +2620,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2278
2620
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2279
2621
  use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2280
2622
  use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2623
+ use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
2281
2624
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2282
2625
  }
2283
2626
 
@@ -2295,6 +2638,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2295
2638
  bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2296
2639
  bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2297
2640
 
2641
+ const int32_t hint = ggml_get_op_params_i32(dst, 1);
2642
+ if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
2643
+ return;
2644
+ }
2645
+
2298
2646
  if (!split && use_mul_mat_vec_f) {
2299
2647
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
2300
2648
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2338,7 +2686,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2338
2686
  static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
2339
2687
  if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
2340
2688
  if (ggml_is_quantized(src0->type)) {
2341
- if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
2689
+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
2690
+ if (ne2 <= mmvq_mmid_max) {
2342
2691
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2343
2692
  return;
2344
2693
  }
@@ -2478,11 +2827,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2478
2827
  }
2479
2828
 
2480
2829
  static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
2481
- // why is this here instead of mul_mat?
2482
- if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
2483
- ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
2484
- }
2485
-
2486
2830
  switch (dst->op) {
2487
2831
  case GGML_OP_ARGMAX:
2488
2832
  ggml_cuda_argmax(ctx, dst);
@@ -2835,26 +3179,54 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
2835
3179
  static void ggml_backend_cuda_free(ggml_backend_t backend) {
2836
3180
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
2837
3181
 
3182
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
3183
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
3184
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
3185
+ dev_ctx->active_count--;
3186
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
3187
+
2838
3188
  delete cuda_ctx;
2839
3189
  delete backend;
2840
3190
  }
2841
3191
 
2842
3192
  static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2843
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3193
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
2844
3194
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2845
3195
 
2846
3196
  GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
2847
3197
 
2848
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
3198
+ CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
2849
3199
  }
2850
3200
 
2851
3201
  static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2852
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3202
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3203
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
3204
+
3205
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
3206
+
3207
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
3208
+ }
3209
+
3210
+ static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data,
3211
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
3212
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3213
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
3214
+
3215
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
3216
+
3217
+ CUDA_CHECK(cudaMemcpy2DAsync(
3218
+ (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream()));
3219
+ }
3220
+
3221
+ static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data,
3222
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
3223
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
2853
3224
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2854
3225
 
2855
3226
  GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
2856
3227
 
2857
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
3228
+ CUDA_CHECK(cudaMemcpy2DAsync(
3229
+ data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
2858
3230
  }
2859
3231
 
2860
3232
  static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
@@ -2865,21 +3237,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
2865
3237
  return false;
2866
3238
  }
2867
3239
 
2868
- if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
3240
+ if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
2869
3241
  return false;
2870
3242
  }
2871
3243
 
2872
3244
  // device -> device copy
2873
- ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
2874
- ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
3245
+ ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context;
3246
+ ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context;
2875
3247
 
2876
- ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
2877
- ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
3248
+ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
3249
+ ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
2878
3250
 
2879
3251
  if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
2880
3252
  #ifndef NDEBUG
2881
3253
  GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
2882
- #endif
3254
+ #endif // NDEBUG
2883
3255
  return false;
2884
3256
  }
2885
3257
 
@@ -2892,7 +3264,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
2892
3264
  return false;
2893
3265
  #else
2894
3266
  CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
2895
- #endif
3267
+ #endif // GGML_CUDA_NO_PEER_COPY
2896
3268
  }
2897
3269
 
2898
3270
  // record event on src stream after the copy
@@ -2941,14 +3313,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2941
3313
  }
2942
3314
 
2943
3315
  // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2944
- if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
2945
- // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2946
- // TODO: figure out a way to enable for larger batch sizes, without hurting performance
2947
- // ref: https://github.com/ggml-org/llama.cpp/pull/18958
2948
- use_cuda_graph = false;
3316
+ if (node->op == GGML_OP_MUL_MAT_ID) {
3317
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
3318
+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
3319
+ if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
3320
+ // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
3321
+ // TODO: figure out a way to enable for larger batch sizes, without hurting performance
3322
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18958
3323
+ use_cuda_graph = false;
2949
3324
  #ifndef NDEBUG
2950
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
3325
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2951
3326
  #endif
3327
+ }
2952
3328
  }
2953
3329
 
2954
3330
  if (!use_cuda_graph) {
@@ -2959,135 +3335,51 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2959
3335
  return use_cuda_graph;
2960
3336
  }
2961
3337
 
2962
- static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2963
- memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
2964
- props->node_data = node->data;
2965
- props->node_op = node->op;
2966
- props->node_type = node->type;
2967
- props->flags = node->flags;
2968
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2969
- props->ne[i] = node->ne[i];
2970
- props->nb[i] = node->nb[i];
2971
- }
2972
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2973
- if (!node->src[i]) {
2974
- continue;
2975
- }
2976
-
2977
- props->src_data[i] = node->src[i]->data;
2978
- }
2979
- memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
3338
+ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
3339
+ return cgraph->nodes[0];
2980
3340
  }
2981
3341
 
2982
- static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2983
- if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
2984
- return false;
2985
- }
3342
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3343
+ bool res = false;
2986
3344
 
2987
- if (node->op != props->node_op) {
2988
- return false;
2989
- }
3345
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3346
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
2990
3347
 
2991
- if (node->type != props->node_type) {
3348
+ if (cgraph->uid != 0 &&
3349
+ cgraph->uid == graph->uid) {
3350
+ GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid);
3351
+ GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes);
2992
3352
  return false;
2993
3353
  }
2994
3354
 
2995
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2996
- if (node->ne[i] != props->ne[i]) {
2997
- return false;
2998
- }
2999
- if (node->nb[i] != props->nb[i]) {
3000
- return false;
3001
- }
3355
+ graph->uid = cgraph->uid;
3356
+
3357
+ // Check if the graph size has changed
3358
+ if ((int)graph->node_props.size() != cgraph->n_nodes) {
3359
+ res = true;
3360
+ graph->node_props.resize(cgraph->n_nodes);
3002
3361
  }
3003
3362
 
3004
- if (node->op != GGML_OP_VIEW) {
3005
- for (int i = 0; i < GGML_MAX_SRC; i++) {
3006
- if (!node->src[i]) {
3007
- if (props->src_data[i] != nullptr) {
3008
- return false;
3009
- }
3010
- continue;
3363
+ for (int i = 0; i < cgraph->n_nodes; i++) {
3364
+ ggml_cuda_graph::node_properties prop = {};
3365
+ memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
3366
+
3367
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
3368
+ if (cgraph->nodes[i]->src[j]) {
3369
+ prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
3370
+ memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
3371
+ memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
3011
3372
  }
3373
+ }
3012
3374
 
3013
- if (node->src[i]->data != props->src_data[i]) {
3014
- return false;
3015
- }
3375
+ if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3376
+ graph->node_props[i] = prop;
3377
+ res = true;
3016
3378
  }
3017
3379
  }
3018
3380
 
3019
- if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3020
- return false;
3021
- }
3022
-
3023
- if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
3024
- return false;
3025
- }
3026
-
3027
- return true;
3028
- }
3029
-
3030
- static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
3031
- return cgraph->nodes[0];
3032
- }
3033
-
3034
- static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3035
- bool res = false;
3036
-
3037
- const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3038
- ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3039
-
3040
- // Check if the graph size has changed
3041
- if (graph->props.size() != (size_t)cgraph->n_nodes) {
3042
- res = true;
3043
- graph->props.resize(cgraph->n_nodes);
3044
- }
3045
-
3046
- // Loop over nodes in GGML graph to determine if CUDA graph update is required
3047
- // and store properties to allow this comparison for the next token
3048
- std::unordered_set<ggml_tensor *> seen_node;
3049
- std::vector<ggml_tensor *> srcs_extra;
3050
- for (int i = 0; i < cgraph->n_nodes; i++) {
3051
- bool props_match = true;
3052
-
3053
- seen_node.insert(cgraph->nodes[i]);
3054
-
3055
- if (!res) {
3056
- props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
3057
- }
3058
- if (!props_match) {
3059
- res = true;
3060
- }
3061
- ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
3062
-
3063
- for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3064
- ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
3065
- if (src && seen_node.find(src) == seen_node.end()) {
3066
- srcs_extra.push_back(src);
3067
- }
3068
- }
3069
- }
3070
-
3071
- if (graph->extra.size() != (size_t) srcs_extra.size()) {
3072
- res = true;
3073
- graph->extra.resize(srcs_extra.size());
3074
- }
3075
-
3076
- for (size_t i = 0; i < srcs_extra.size(); ++i) {
3077
- bool props_match = true;
3078
-
3079
- if (!res) {
3080
- props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
3081
- }
3082
-
3083
- if (!props_match) {
3084
- res = true;
3085
- }
3086
- ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
3087
- }
3088
-
3089
- return res;
3090
- }
3381
+ return res;
3382
+ }
3091
3383
 
3092
3384
  static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3093
3385
  ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
@@ -3298,6 +3590,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
3298
3590
  return true;
3299
3591
  }
3300
3592
 
3593
+ // returns whether the write (out) nodes overwrite the read nodes in operation
3594
+ static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
3595
+ const int node_idx,
3596
+ const int node_count,
3597
+ const int * out_nodes,
3598
+ const int out_count,
3599
+ const bool is_topk_moe = false) {
3600
+ auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3601
+ const int64_t a_start = (int64_t) a->data;
3602
+ const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a);
3603
+
3604
+ const int64_t b_start = (int64_t) b->data;
3605
+ const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b);
3606
+
3607
+ if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3608
+ return true;
3609
+ }
3610
+
3611
+ return false;
3612
+ };
3613
+
3614
+ bool is_ok = true;
3615
+ // exception for topk-moe, as each row is read entirely before writing
3616
+ if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
3617
+ return true;
3618
+ }
3619
+
3620
+ for (int i = 0; i < out_count; ++i) {
3621
+ const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
3622
+
3623
+ for (int j = node_idx; j < node_idx + node_count; ++j) {
3624
+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
3625
+ // the destination and the src is not an intermediate node that's being
3626
+ // elided, then disable fusion.
3627
+
3628
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3629
+ const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
3630
+
3631
+ if (!src || src->op == GGML_OP_NONE) {
3632
+ continue;
3633
+ }
3634
+
3635
+ if (nodes_overlap(dst, src)) {
3636
+ bool found = false;
3637
+
3638
+ for (int k = node_idx; k < j; ++k) {
3639
+ if (cgraph->nodes[k] == src) {
3640
+ found = true;
3641
+ break;
3642
+ }
3643
+ }
3644
+
3645
+ if (!found) {
3646
+ is_ok = false;
3647
+ break;
3648
+ }
3649
+ }
3650
+ }
3651
+ }
3652
+ }
3653
+
3654
+ return is_ok;
3655
+ }
3656
+
3657
+
3301
3658
  static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3302
3659
  int node_idx,
3303
3660
  std::initializer_list<enum ggml_op> ops,
@@ -3327,7 +3684,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3327
3684
  const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
3328
3685
 
3329
3686
  if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3330
- return true;
3687
+ int out_nodes[] = { node_idx + 4 };
3688
+ return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
3331
3689
  }
3332
3690
  }
3333
3691
 
@@ -3338,7 +3696,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3338
3696
  const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
3339
3697
 
3340
3698
  if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3341
- return true;
3699
+ int out_nodes[] = { node_idx + 2 };
3700
+ return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
3342
3701
  }
3343
3702
  }
3344
3703
 
@@ -3404,6 +3763,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3404
3763
  && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3405
3764
  const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3406
3765
  const ggml_tensor * silu = cgraph->nodes[node_idx+1];
3766
+ if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3767
+ return false;
3768
+ }
3407
3769
 
3408
3770
  if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3409
3771
  return false;
@@ -3412,6 +3774,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3412
3774
  return true;
3413
3775
  }
3414
3776
 
3777
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
3778
+ && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3779
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3780
+ const ggml_tensor * add = cgraph->nodes[node_idx+1];
3781
+ const ggml_tensor * silu = cgraph->nodes[node_idx+2];
3782
+ if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3783
+ return false;
3784
+ }
3785
+
3786
+ if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3787
+ return false;
3788
+ }
3789
+
3790
+ // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
3791
+ const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
3792
+ if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
3793
+ return false;
3794
+ }
3795
+ if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
3796
+ return false;
3797
+ }
3798
+
3799
+ return true;
3800
+ }
3801
+
3415
3802
  if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
3416
3803
  && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
3417
3804
  const ggml_tensor * unary = cgraph->nodes[node_idx];
@@ -3440,6 +3827,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3440
3827
  return true;
3441
3828
  }
3442
3829
 
3830
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
3831
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
3832
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
3833
+ const ggml_tensor * sqr = cgraph->nodes[node_idx+1];
3834
+
3835
+ if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
3836
+ return false;
3837
+ }
3838
+
3839
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
3840
+ return false;
3841
+ }
3842
+
3843
+ if (unary->type != sqr->type) {
3844
+ return false;
3845
+ }
3846
+
3847
+ if (!ggml_is_contiguous(unary->src[0])) {
3848
+ return false;
3849
+ }
3850
+
3851
+ return true;
3852
+ }
3853
+
3443
3854
  if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
3444
3855
  && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
3445
3856
  const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -3464,67 +3875,404 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3464
3875
  return false;
3465
3876
  }
3466
3877
 
3467
- // returns whether the write (out) nodes overwrite the read nodes in operation
3468
- static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
3469
- int node_idx,
3470
- int node_count,
3471
- int * out_nodes,
3472
- int out_count) {
3473
- auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3474
- const int64_t a_start = (int64_t) a->data;
3475
- const int64_t a_end = a_start + ggml_nbytes(a);
3878
+ // try and fuse nodes and return the number of nodes to skip
3879
+ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) {
3880
+
3881
+ static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION"));
3882
+ if (disable_fusion) {
3883
+ return 0;
3884
+ }
3885
+
3886
+ ggml_tensor * node = cgraph->nodes[i];
3887
+
3888
+ //topk-moe
3889
+ if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3890
+ cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3891
+ ggml_cuda_topk_moe_args args;
3892
+ const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3893
+ std::vector<ggml_op> ops;
3894
+
3895
+ if (can_fuse) {
3896
+ const ggml_tensor * logits = node->src[0];
3897
+ ggml_tensor * weights = nullptr;
3898
+ ggml_tensor * ids = nullptr;
3899
+ const ggml_tensor * bias = nullptr;
3900
+ const ggml_tensor * clamp = nullptr;
3901
+ const ggml_tensor * scale = nullptr;
3902
+
3903
+ if (!args.delayed_softmax) {
3904
+ ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3905
+ int out_nodes[2]; // nodes which can't be elided
3906
+
3907
+ if (args.prob_bias) {
3908
+ bias = cgraph->nodes[i + 2]->src[1];
3909
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW,
3910
+ GGML_OP_GET_ROWS });
3911
+ out_nodes[0] = i + 4;
3912
+ ids = cgraph->nodes[i + 4];
3913
+ } else {
3914
+ ops.insert(ops.end(),
3915
+ { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS });
3916
+ out_nodes[0] = i + 3;
3917
+ ids = cgraph->nodes[i + 3];
3918
+ }
3476
3919
 
3477
- const int64_t b_start = (int64_t) b->data;
3478
- const int64_t b_end = b_start + ggml_nbytes(b);
3920
+ if (args.norm) {
3921
+ ops.insert(ops.end(),
3922
+ { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE });
3923
+ clamp = cgraph->nodes[i + ops.size() - 3];
3924
+ }
3925
+ if (args.scale) {
3926
+ ops.insert(ops.end(), { GGML_OP_SCALE });
3927
+ scale = cgraph->nodes[i + ops.size() - 1];
3928
+ }
3479
3929
 
3480
- if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3481
- return true;
3930
+ weights = cgraph->nodes[i + ops.size() - 1];
3931
+ out_nodes[1] = i + ops.size() - 1;
3932
+
3933
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3934
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3935
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
3936
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3937
+ return ops.size() - 1;
3938
+ }
3939
+ } else if (!args.norm && !args.prob_bias) {
3940
+ //special case gpt-oss, no norm, no bias.
3941
+ ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
3942
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3943
+ weights = cgraph->nodes[i + 5];
3944
+ ids = cgraph->nodes[i + 1];
3945
+ const ggml_tensor * softmax = cgraph->nodes[i + 4];
3946
+
3947
+ int out_nodes[2] = { i + 1, i + 5 };
3948
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3949
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3950
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
3951
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3952
+ return ops.size() - 1;
3953
+ }
3954
+ }
3482
3955
  }
3956
+ }
3483
3957
 
3484
- return false;
3485
- };
3958
+ //RoPE + view + set-rows
3959
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3960
+ ggml_tensor * rope = cgraph->nodes[i];
3961
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
3486
3962
 
3487
- bool is_ok = true;
3488
- // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
3489
- if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
3490
- return true;
3963
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3964
+ return 2;
3491
3965
  }
3492
3966
 
3493
- for (int i = 0; i < out_count; ++i) {
3494
- const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
3967
+ // Snake activation: y = x + sin(a*x)^2 * inv_b
3968
+ // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
3969
+ if (ggml_can_fuse_subgraph(cgraph, i,
3970
+ { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
3971
+ { i + 4 })) {
3972
+ const ggml_tensor * mul0 = cgraph->nodes[i];
3973
+ const ggml_tensor * sqr = cgraph->nodes[i + 2];
3974
+ const ggml_tensor * mul1 = cgraph->nodes[i + 3];
3975
+ ggml_tensor * add = cgraph->nodes[i + 4];
3495
3976
 
3496
- for (int j = node_idx; j < node_idx + node_count; ++j) {
3497
- // Loop over all srcs of all nodes in the fusion. If the src overlaps
3498
- // the destination and the src is not an intermediate node that's being
3499
- // elided, then disable fusion.
3977
+ // x carries the full activation shape, a is the broadcast operand
3978
+ const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
3979
+ const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
3500
3980
 
3501
- for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3502
- const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
3981
+ // mul1 reads sqr and inv_b in either operand order
3982
+ const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
3503
3983
 
3504
- if (!src || src->op == GGML_OP_NONE) {
3505
- continue;
3506
- }
3984
+ // closure check: the trailing add must read the same x as the leading mul
3985
+ const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
3507
3986
 
3508
- if (nodes_overlap(dst, src)) {
3509
- bool found = false;
3987
+ // Kernel iterates over total = T * C, so x and add must be 2D and
3988
+ // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
3989
+ const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
3990
+ (add->ne[2] == 1 && add->ne[3] == 1) &&
3991
+ (a->ne[2] == 1 && a->ne[3] == 1);
3992
+ const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
3510
3993
 
3511
- for (int k = node_idx; k < j; ++k) {
3512
- if (cgraph->nodes[k] == src) {
3513
- found = true;
3514
- break;
3515
- }
3516
- }
3994
+ // x must be in the supported whitelist and every operand / intermediate
3995
+ // result must share x's type, since launch_snake casts a / inv_b as
3996
+ // float and templates the kernel on a single T. Mixed precision chains
3997
+ // fall back to the naive path.
3998
+ const ggml_tensor * sin1 = cgraph->nodes[i + 1];
3999
+ const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
4000
+ (a->type == x->type) && (inv_b->type == x->type) &&
4001
+ (mul0->type == x->type) && (sin1->type == x->type) &&
4002
+ (sqr->type == x->type) && (mul1->type == x->type) &&
4003
+ (add->type == x->type);
3517
4004
 
3518
- if (!found) {
3519
- is_ok = false;
3520
- break;
4005
+ if (types_ok && shape_ok && dim_ok && x_in_add == x) {
4006
+ ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
4007
+ return 4;
4008
+ }
4009
+ }
4010
+
4011
+ // multi-(add or mul)
4012
+ if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
4013
+ int n_fuse = 0;
4014
+ ggml_op ops[8];
4015
+ std::fill(ops, ops + 8, node->op);
4016
+
4017
+ for (; n_fuse <= 6; ++n_fuse) {
4018
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
4019
+ break;
4020
+ }
4021
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
4022
+ break;
4023
+ }
4024
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
4025
+ break;
4026
+ }
4027
+ }
4028
+
4029
+ n_fuse++;
4030
+
4031
+ if (n_fuse > 1) {
4032
+ ggml_tensor fused_node;
4033
+ memcpy(&fused_node, node, sizeof(ggml_tensor));
4034
+ for (int j = 0; j < n_fuse - 1; ++j) {
4035
+ fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
4036
+ }
4037
+ fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
4038
+ if (node->op == GGML_OP_ADD) {
4039
+ ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
4040
+ } else {
4041
+ ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
4042
+ }
4043
+ return n_fuse - 1;
4044
+ }
4045
+ }
4046
+
4047
+ bool fused_mul_mat_vec = false;
4048
+ int fused_node_count = 0;
4049
+
4050
+ // gate + glu + up
4051
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
4052
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
4053
+
4054
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
4055
+ ggml_tensor * glu = cgraph->nodes[i + 4];
4056
+ ggml_tensor * gate_bias_n = glu->src[0];
4057
+ ggml_tensor * up_bias_n = glu->src[1];
4058
+
4059
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
4060
+ ggml_tensor * gate_n = nullptr;
4061
+ ggml_tensor * up_n = nullptr;
4062
+
4063
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
4064
+ gate_n = cgraph->nodes[i];
4065
+ up_n = cgraph->nodes[i + 2];
4066
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
4067
+ gate_n = cgraph->nodes[i + 2];
4068
+ up_n = cgraph->nodes[i];
4069
+ } else {
4070
+ continue;
4071
+ }
4072
+
4073
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
4074
+ if (op_bias == GGML_OP_ADD) {
4075
+ if (bias_node->src[0] == mul_node) {
4076
+ return bias_node->src[1];
4077
+ }
4078
+ if (bias_node->src[1] == mul_node) {
4079
+ return bias_node->src[0];
3521
4080
  }
4081
+ return (ggml_tensor *) nullptr;
3522
4082
  }
4083
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
4084
+ GGML_ASSERT(bias_node->src[0] == mul_node);
4085
+ return bias_node->src[1];
4086
+ };
4087
+
4088
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
4089
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
4090
+
4091
+ if (!up_bias_tensor || !gate_bias_tensor) {
4092
+ continue;
4093
+ }
4094
+
4095
+ // we don't support repeating adds
4096
+ if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
4097
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
4098
+ continue;
4099
+ }
4100
+
4101
+ const ggml_tensor * src0 = up_n->src[0];
4102
+ const ggml_tensor * src1 = up_n->src[1];
4103
+ const ggml_tensor * ids = up_n->src[2];
4104
+
4105
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
4106
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4107
+ fusion_data.gate = gate_n->src[0];
4108
+ fusion_data.x_bias = up_bias_tensor;
4109
+ fusion_data.gate_bias = gate_bias_tensor;
4110
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4111
+
4112
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4113
+ fused_mul_mat_vec = true;
4114
+ fused_node_count = 5;
4115
+ break;
4116
+ }
4117
+
4118
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
4119
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4120
+ fusion_data.gate = gate_n->src[0];
4121
+ fusion_data.x_bias = up_bias_tensor;
4122
+ fusion_data.gate_bias = gate_bias_tensor;
4123
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4124
+
4125
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4126
+ fused_mul_mat_vec = true;
4127
+ fused_node_count = 5;
4128
+ break;
4129
+ }
4130
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
4131
+ ggml_tensor * glu = cgraph->nodes[i + 2];
4132
+ ggml_tensor * gate = glu->src[0];
4133
+ ggml_tensor * up = glu->src[1];
4134
+
4135
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) ||
4136
+ (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
4137
+
4138
+ if (!ok) {
4139
+ continue;
4140
+ }
4141
+
4142
+ const ggml_tensor * src0 = up->src[0];
4143
+ const ggml_tensor * src1 = up->src[1];
4144
+ const ggml_tensor * ids = up->src[2];
4145
+
4146
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
4147
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4148
+ fusion_data.gate = gate->src[0];
4149
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4150
+
4151
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4152
+ fused_mul_mat_vec = true;
4153
+ fused_node_count = 3;
4154
+ break;
4155
+ }
4156
+
4157
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
4158
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4159
+ fusion_data.gate = gate->src[0];
4160
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4161
+
4162
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4163
+ fused_mul_mat_vec = true;
4164
+ fused_node_count = 3;
4165
+ break;
3523
4166
  }
3524
4167
  }
3525
4168
  }
3526
4169
 
3527
- return is_ok;
4170
+ if (fused_mul_mat_vec) {
4171
+ return fused_node_count - 1;
4172
+ }
4173
+
4174
+ fused_mul_mat_vec = false;
4175
+ fused_node_count = 0;
4176
+
4177
+ // gate + add + glu + up + add
4178
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
4179
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
4180
+
4181
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
4182
+ continue;
4183
+ }
4184
+
4185
+ ggml_tensor * mm_node = cgraph->nodes[i];
4186
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
4187
+
4188
+ ggml_tensor * bias_tensor = nullptr;
4189
+ if (bias_op == GGML_OP_ADD) {
4190
+ if (bias_node->src[0] == mm_node) {
4191
+ bias_tensor = bias_node->src[1];
4192
+ } else if (bias_node->src[1] == mm_node) {
4193
+ bias_tensor = bias_node->src[0];
4194
+ } else {
4195
+ continue;
4196
+ }
4197
+ } else {
4198
+ if (bias_node->src[0] != mm_node) {
4199
+ continue;
4200
+ }
4201
+ bias_tensor = bias_node->src[1];
4202
+ }
4203
+
4204
+ const ggml_tensor * src0 = mm_node->src[0];
4205
+ const ggml_tensor * src1 = mm_node->src[1];
4206
+ const ggml_tensor * ids = mm_node->src[2];
4207
+
4208
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
4209
+ continue;
4210
+ }
4211
+
4212
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
4213
+ continue;
4214
+ }
4215
+
4216
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4217
+ fusion_data.x_bias = bias_tensor;
4218
+
4219
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
4220
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
4221
+ fused_mul_mat_vec = true;
4222
+ fused_node_count = 2;
4223
+ break;
4224
+ }
4225
+
4226
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
4227
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
4228
+ fused_mul_mat_vec = true;
4229
+ fused_node_count = 2;
4230
+ break;
4231
+ }
4232
+ }
4233
+
4234
+ if (fused_mul_mat_vec) {
4235
+ return fused_node_count - 1;
4236
+ }
4237
+
4238
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
4239
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
4240
+ return 2;
4241
+ }
4242
+
4243
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
4244
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
4245
+ return 1;
4246
+ }
4247
+
4248
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
4249
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
4250
+ return 2;
4251
+ }
4252
+
4253
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
4254
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
4255
+ return 1;
4256
+ }
4257
+
4258
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
4259
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
4260
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
4261
+ ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]);
4262
+ return 1;
4263
+ }
4264
+
4265
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
4266
+ ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]);
4267
+ return 1;
4268
+ }
4269
+
4270
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
4271
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node);
4272
+ return 2;
4273
+ }
4274
+
4275
+ return 0;
3528
4276
  }
3529
4277
 
3530
4278
  static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
@@ -3673,345 +4421,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3673
4421
  continue;
3674
4422
  }
3675
4423
 
3676
- // start of fusion operations
3677
- static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
3678
- if (!disable_fusion) {
3679
- ggml_cuda_topk_moe_args args;
3680
-
3681
- if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3682
- cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3683
- const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3684
-
3685
- std::vector<ggml_op> ops;
3686
-
3687
- if (can_fuse) {
3688
- const ggml_tensor * logits = node->src[0];
3689
- ggml_tensor * weights = nullptr;
3690
- ggml_tensor * ids = nullptr;
3691
- const ggml_tensor * bias = nullptr;
3692
- const ggml_tensor * clamp = nullptr;
3693
- const ggml_tensor * scale = nullptr;
3694
-
3695
- if (!args.delayed_softmax) {
3696
- ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3697
- int out_nodes[2]; // nodes which can't be elided
3698
-
3699
- if (args.prob_bias) {
3700
- bias = cgraph->nodes[i + 2]->src[1];
3701
- ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
3702
- GGML_OP_VIEW, GGML_OP_GET_ROWS });
3703
- out_nodes[0] = i + 4;
3704
- ids = cgraph->nodes[i + 4];
3705
- } else {
3706
- ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
3707
- GGML_OP_GET_ROWS });
3708
- out_nodes[0] = i + 3;
3709
- ids = cgraph->nodes[i + 3];
3710
- }
3711
-
3712
- if (args.norm) {
3713
- ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
3714
- GGML_OP_DIV, GGML_OP_RESHAPE });
3715
- clamp = cgraph->nodes[i + ops.size() - 3];
3716
- }
3717
- if (args.scale) {
3718
- ops.insert(ops.end(), { GGML_OP_SCALE });
3719
- scale = cgraph->nodes[i + ops.size() - 1];
3720
- }
3721
-
3722
- weights = cgraph->nodes[i + ops.size() - 1];
3723
- out_nodes[1] = i + ops.size() - 1;
3724
-
3725
- if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3726
- ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3727
- ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3728
- ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3729
- i += ops.size() - 1;
3730
- continue;
3731
- }
3732
- } else if (!args.norm && !args.prob_bias) {
3733
- //special case gpt-oss, no norm, no bias.
3734
- ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
3735
- GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3736
- weights = cgraph->nodes[i + 5];
3737
- ids = cgraph->nodes[i + 1];
3738
- const ggml_tensor * softmax = cgraph->nodes[i + 4];
3739
-
3740
- int out_nodes[2] = { i + 1, i + 5 };
3741
- if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3742
- ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3743
- ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3744
- ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3745
- i += ops.size() - 1;
3746
- continue;
3747
- }
3748
- }
3749
- }
3750
- }
3751
-
3752
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3753
- ggml_tensor * rope = cgraph->nodes[i];
3754
- ggml_tensor * set_rows = cgraph->nodes[i + 2];
3755
-
3756
- ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3757
- i += 2;
3758
- continue;
3759
- }
3760
-
3761
- if (node->op == GGML_OP_ADD) {
3762
- int n_fuse = 0;
3763
- ggml_op ops[8];
3764
- std::fill(ops, ops + 8, GGML_OP_ADD);
3765
-
3766
- for (; n_fuse <= 6; ++n_fuse){
3767
- if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
3768
- break;
3769
- }
3770
- if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
3771
- break;
3772
- }
3773
- if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
3774
- break;
3775
- }
3776
- }
3777
-
3778
- n_fuse++;
3779
-
3780
- if (n_fuse > 1) {
3781
- ggml_tensor fused_add_node;
3782
- memcpy(&fused_add_node, node, sizeof(ggml_tensor));
3783
- for (int j = 0; j < n_fuse - 1; ++j) {
3784
- fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3785
- }
3786
- fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
3787
- ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
3788
- i += n_fuse - 1;
3789
-
3790
- continue;
3791
- }
3792
- }
3793
-
3794
- bool fused_mul_mat_vec = false;
3795
- int fused_node_count = 0;
3796
-
3797
- for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3798
- const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3799
-
3800
- if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3801
- ggml_tensor * glu = cgraph->nodes[i + 4];
3802
- ggml_tensor * gate_bias_n = glu->src[0];
3803
- ggml_tensor * up_bias_n = glu->src[1];
3804
-
3805
- //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3806
- ggml_tensor * gate_n = nullptr;
3807
- ggml_tensor * up_n = nullptr;
3808
-
3809
- if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3810
- gate_n = cgraph->nodes[i];
3811
- up_n = cgraph->nodes[i + 2];
3812
- } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3813
- gate_n = cgraph->nodes[i + 2];
3814
- up_n = cgraph->nodes[i];
3815
- } else {
3816
- continue;
3817
- }
3818
-
3819
- auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3820
- if (op_bias == GGML_OP_ADD) {
3821
- if (bias_node->src[0] == mul_node) {
3822
- return bias_node->src[1];
3823
- }
3824
- if (bias_node->src[1] == mul_node) {
3825
- return bias_node->src[0];
3826
- }
3827
- return (ggml_tensor *) nullptr;
3828
- }
3829
- GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3830
- GGML_ASSERT(bias_node->src[0] == mul_node);
3831
- return bias_node->src[1];
3832
- };
3833
-
3834
- ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
3835
- ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3836
-
3837
- if (!up_bias_tensor || !gate_bias_tensor) {
3838
- continue;
3839
- }
3840
-
3841
- // we don't support repeating adds
3842
- if (bias_op == GGML_OP_ADD &&
3843
- (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3844
- !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3845
- continue;
3846
- }
3847
-
3848
- const ggml_tensor * src0 = up_n->src[0];
3849
- const ggml_tensor * src1 = up_n->src[1];
3850
- const ggml_tensor * ids = up_n->src[2];
3851
-
3852
- if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3853
- ggml_cuda_mm_fusion_args_host fusion_data{};
3854
- fusion_data.gate = gate_n->src[0];
3855
- fusion_data.x_bias = up_bias_tensor;
3856
- fusion_data.gate_bias = gate_bias_tensor;
3857
- fusion_data.glu_op = ggml_get_glu_op(glu);
3858
-
3859
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3860
- fused_mul_mat_vec = true;
3861
- fused_node_count = 5;
3862
- break;
3863
- }
3864
-
3865
- if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3866
- ggml_cuda_mm_fusion_args_host fusion_data{};
3867
- fusion_data.gate = gate_n->src[0];
3868
- fusion_data.x_bias = up_bias_tensor;
3869
- fusion_data.gate_bias = gate_bias_tensor;
3870
- fusion_data.glu_op = ggml_get_glu_op(glu);
3871
-
3872
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3873
- fused_mul_mat_vec = true;
3874
- fused_node_count = 5;
3875
- break;
3876
- }
3877
- } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3878
- ggml_tensor * glu = cgraph->nodes[i + 2];
3879
- ggml_tensor * gate = glu->src[0];
3880
- ggml_tensor * up = glu->src[1];
3881
-
3882
- bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3883
- || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3884
-
3885
- if (!ok) continue;
3886
-
3887
- const ggml_tensor * src0 = up->src[0];
3888
- const ggml_tensor * src1 = up->src[1];
3889
- const ggml_tensor * ids = up->src[2];
3890
-
3891
- if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3892
- ggml_cuda_mm_fusion_args_host fusion_data{};
3893
- fusion_data.gate = gate->src[0];
3894
- fusion_data.glu_op = ggml_get_glu_op(glu);
3895
-
3896
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3897
- fused_mul_mat_vec = true;
3898
- fused_node_count = 3;
3899
- break;
3900
- }
3901
-
3902
- if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3903
- ggml_cuda_mm_fusion_args_host fusion_data{};
3904
- fusion_data.gate = gate->src[0];
3905
- fusion_data.glu_op = ggml_get_glu_op(glu);
3906
-
3907
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3908
- fused_mul_mat_vec = true;
3909
- fused_node_count = 3;
3910
- break;
3911
- }
3912
- }
3913
- }
3914
-
3915
- if (fused_mul_mat_vec) {
3916
- i += fused_node_count - 1;
3917
- continue;
3918
- }
3919
-
3920
- fused_mul_mat_vec = false;
3921
- fused_node_count = 0;
3922
-
3923
- for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3924
- const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3925
-
3926
- if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3927
- continue;
3928
- }
3929
-
3930
- ggml_tensor * mm_node = cgraph->nodes[i];
3931
- ggml_tensor * bias_node = cgraph->nodes[i + 1];
4424
+ int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i);
3932
4425
 
3933
- ggml_tensor * bias_tensor = nullptr;
3934
- if (bias_op == GGML_OP_ADD) {
3935
- if (bias_node->src[0] == mm_node) {
3936
- bias_tensor = bias_node->src[1];
3937
- } else if (bias_node->src[1] == mm_node) {
3938
- bias_tensor = bias_node->src[0];
3939
- } else {
3940
- continue;
3941
- }
3942
- } else {
3943
- if (bias_node->src[0] != mm_node) {
3944
- continue;
3945
- }
3946
- bias_tensor = bias_node->src[1];
3947
- }
3948
-
3949
- const ggml_tensor * src0 = mm_node->src[0];
3950
- const ggml_tensor * src1 = mm_node->src[1];
3951
- const ggml_tensor * ids = mm_node->src[2];
3952
-
3953
- if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3954
- continue;
3955
- }
3956
-
3957
- if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3958
- continue;
3959
- }
3960
-
3961
- ggml_cuda_mm_fusion_args_host fusion_data{};
3962
- fusion_data.x_bias = bias_tensor;
3963
-
3964
- if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3965
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3966
- fused_mul_mat_vec = true;
3967
- fused_node_count = 2;
3968
- break;
3969
- }
3970
-
3971
- if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3972
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3973
- fused_mul_mat_vec = true;
3974
- fused_node_count = 2;
3975
- break;
3976
- }
3977
- }
3978
-
3979
- if (fused_mul_mat_vec) {
3980
- i += fused_node_count - 1;
3981
- continue;
3982
- }
3983
-
3984
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3985
- ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
3986
- i += 2;
3987
- continue;
3988
- }
3989
-
3990
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
3991
- ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
3992
- i++;
3993
- continue;
3994
- }
3995
-
3996
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
3997
- ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
3998
- i++;
3999
- continue;
4000
- }
4001
-
4002
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
4003
- ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
4004
- ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
4005
- ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
4006
- i++;
4007
- continue;
4008
- }
4009
-
4010
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
4011
- i += 2;
4012
- ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
4013
- continue;
4014
- }
4426
+ if (nodes_to_skip != 0) {
4427
+ i += nodes_to_skip;
4428
+ continue;
4015
4429
  }
4016
4430
  #ifndef NDEBUG
4017
4431
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
@@ -4425,6 +4839,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
4425
4839
  /* .free = */ ggml_backend_cuda_free,
4426
4840
  /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
4427
4841
  /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
4842
+ /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
4843
+ /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
4428
4844
  /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
4429
4845
  /* .synchronize = */ ggml_backend_cuda_synchronize,
4430
4846
  /* .graph_plan_create = */ NULL,
@@ -4500,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
4500
4916
 
4501
4917
  // backend device
4502
4918
 
4503
- struct ggml_backend_cuda_device_context {
4504
- int device;
4505
- std::string name;
4506
- std::string description;
4507
- std::string pci_bus_id;
4508
- int op_offload_min_batch_size;
4509
- };
4510
-
4511
4919
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
4512
4920
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
4513
4921
  return ctx->name.c_str();
@@ -4596,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
4596
5004
 
4597
5005
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
4598
5006
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
5007
+
5008
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5009
+ std::lock_guard<std::mutex> lock(ctx->device_mutex);
5010
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5011
+
4599
5012
  ggml_cuda_set_device(ctx->device);
4600
5013
  CUDA_CHECK(cudaMemGetInfo(free, total));
4601
5014
 
@@ -4622,11 +5035,24 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
4622
5035
  }
4623
5036
  #endif // defined(__linux__)
4624
5037
 
5038
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5039
+ // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
5040
+ // context that permanently consumes VRAM. Reset the device to free it.
5041
+ if (ctx->active_count == 0) {
5042
+ CUDA_CHECK(cudaDeviceReset());
5043
+ }
5044
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
4625
5045
  }
4626
5046
 
4627
5047
  static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
4628
- GGML_UNUSED(dev);
4629
- return GGML_BACKEND_DEVICE_TYPE_GPU;
5048
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context;
5049
+
5050
+ cudaDeviceProp prop;
5051
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
5052
+
5053
+ return prop.integrated
5054
+ ? GGML_BACKEND_DEVICE_TYPE_IGPU
5055
+ : GGML_BACKEND_DEVICE_TYPE_GPU;
4630
5056
  }
4631
5057
 
4632
5058
  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
@@ -4775,12 +5201,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4775
5201
  switch (a->type) {
4776
5202
  case GGML_TYPE_F32:
4777
5203
  case GGML_TYPE_F16:
5204
+ case GGML_TYPE_Q1_0:
4778
5205
  case GGML_TYPE_Q4_0:
4779
5206
  case GGML_TYPE_Q4_1:
4780
5207
  case GGML_TYPE_Q5_0:
4781
5208
  case GGML_TYPE_Q5_1:
4782
5209
  case GGML_TYPE_Q8_0:
4783
5210
  case GGML_TYPE_MXFP4:
5211
+ case GGML_TYPE_NVFP4:
4784
5212
  case GGML_TYPE_Q2_K:
4785
5213
  case GGML_TYPE_Q3_K:
4786
5214
  case GGML_TYPE_Q4_K:
@@ -4811,6 +5239,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4811
5239
  case GGML_TYPE_F32:
4812
5240
  case GGML_TYPE_BF16:
4813
5241
  case GGML_TYPE_I32:
5242
+ case GGML_TYPE_Q1_0:
4814
5243
  case GGML_TYPE_Q4_0:
4815
5244
  case GGML_TYPE_Q4_1:
4816
5245
  case GGML_TYPE_Q5_0:
@@ -4916,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4916
5345
  case GGML_OP_CONCAT:
4917
5346
  {
4918
5347
  ggml_type src0_type = op->src[0]->type;
4919
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
5348
+ ggml_type src1_type = op->src[1]->type;
5349
+ return src0_type == src1_type &&
5350
+ src0_type == op->type &&
5351
+ !ggml_is_quantized(src0_type) &&
5352
+ ggml_blck_size(src0_type) == 1 &&
5353
+ (ggml_type_size(src0_type) == 1 ||
5354
+ ggml_type_size(src0_type) == 2 ||
5355
+ ggml_type_size(src0_type) == 4 ||
5356
+ ggml_type_size(src0_type) == 8);
4920
5357
  } break;
4921
5358
  case GGML_OP_CONV_TRANSPOSE_1D:
4922
5359
  {
@@ -4942,12 +5379,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4942
5379
  case GGML_OP_VIEW:
4943
5380
  case GGML_OP_PERMUTE:
4944
5381
  case GGML_OP_TRANSPOSE:
4945
- case GGML_OP_ADD:
4946
5382
  case GGML_OP_ADD_ID:
4947
5383
  case GGML_OP_ADD1:
4948
- case GGML_OP_SUB:
4949
- case GGML_OP_MUL:
4950
- case GGML_OP_DIV:
4951
5384
  case GGML_OP_SCALE:
4952
5385
  case GGML_OP_SQR:
4953
5386
  case GGML_OP_SQRT:
@@ -4956,6 +5389,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4956
5389
  case GGML_OP_CLAMP:
4957
5390
  case GGML_OP_LOG:
4958
5391
  return true;
5392
+ case GGML_OP_ADD:
5393
+ case GGML_OP_SUB:
5394
+ case GGML_OP_MUL:
5395
+ case GGML_OP_DIV:
5396
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
5397
+ (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
5398
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
4959
5399
  case GGML_OP_SSM_SCAN: {
4960
5400
  if (op->src[3]->ne[0] == 1) {
4961
5401
  // Mamba2
@@ -5211,6 +5651,15 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
5211
5651
 
5212
5652
  static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
5213
5653
  GGML_UNUSED(reg);
5654
+ if (strcmp(name, "ggml_backend_comm_init") == 0) {
5655
+ return (void *)ggml_backend_cuda_comm_init;
5656
+ }
5657
+ if (strcmp(name, "ggml_backend_comm_free") == 0) {
5658
+ return (void *)ggml_backend_cuda_comm_free;
5659
+ }
5660
+ if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
5661
+ return (void *)ggml_backend_cuda_comm_allreduce_tensor;
5662
+ }
5214
5663
  if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
5215
5664
  return (void *)ggml_backend_cuda_split_buffer_type;
5216
5665
  }
@@ -5254,9 +5703,12 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
5254
5703
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
5255
5704
  dev_ctx->description = prop.name;
5256
5705
 
5257
- char pci_bus_id[16] = {};
5258
- snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
5706
+ char pci_bus_id[32] = {};
5707
+ CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
5259
5708
  dev_ctx->pci_bus_id = pci_bus_id;
5709
+ for (char & c : dev_ctx->pci_bus_id) {
5710
+ c = std::tolower(c);
5711
+ }
5260
5712
  dev_ctx->op_offload_min_batch_size = min_batch_size;
5261
5713
 
5262
5714
  ggml_backend_dev_t dev = new ggml_backend_device {
@@ -5292,13 +5744,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
5292
5744
  return nullptr;
5293
5745
  }
5294
5746
 
5747
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);
5748
+
5295
5749
  ggml_backend_t cuda_backend = new ggml_backend {
5296
5750
  /* .guid = */ ggml_backend_cuda_guid(),
5297
5751
  /* .iface = */ ggml_backend_cuda_interface,
5298
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
5752
+ /* .device = */ dev,
5299
5753
  /* .context = */ ctx,
5300
5754
  };
5301
5755
 
5756
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5757
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
5758
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
5759
+ dev_ctx->active_count++;
5760
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5761
+
5302
5762
  return cuda_backend;
5303
5763
  }
5304
5764