whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -2,6 +2,7 @@
2
2
  #include "ggml-impl.h"
3
3
  #include "ggml-backend-impl.h"
4
4
 
5
+ #include "ggml-cuda/allreduce.cuh"
5
6
  #include "ggml-cuda/common.cuh"
6
7
  #include "ggml-cuda/acc.cuh"
7
8
  #include "ggml-cuda/add-id.cuh"
@@ -23,6 +24,7 @@
23
24
  #include "ggml-cuda/diagmask.cuh"
24
25
  #include "ggml-cuda/diag.cuh"
25
26
  #include "ggml-cuda/fattn.cuh"
27
+ #include "ggml-cuda/fwht.cuh"
26
28
  #include "ggml-cuda/getrows.cuh"
27
29
  #include "ggml-cuda/im2col.cuh"
28
30
  #include "ggml-cuda/mmf.cuh"
@@ -39,6 +41,7 @@
39
41
  #include "ggml-cuda/rope.cuh"
40
42
  #include "ggml-cuda/roll.cuh"
41
43
  #include "ggml-cuda/scale.cuh"
44
+ #include "ggml-cuda/snake.cuh"
42
45
  #include "ggml-cuda/softcap.cuh"
43
46
  #include "ggml-cuda/softmax.cuh"
44
47
  #include "ggml-cuda/ssm-conv.cuh"
@@ -53,6 +56,7 @@
53
56
  #include "ggml-cuda/upscale.cuh"
54
57
  #include "ggml-cuda/wkv.cuh"
55
58
  #include "ggml-cuda/gla.cuh"
59
+ #include "ggml-cuda/gated_delta_net.cuh"
56
60
  #include "ggml-cuda/set.cuh"
57
61
  #include "ggml-cuda/set-rows.cuh"
58
62
  #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -70,20 +74,23 @@
70
74
  #include <condition_variable>
71
75
  #include <cstddef>
72
76
  #include <cstdint>
73
- #include <float.h>
77
+ #include <cfloat>
74
78
  #include <initializer_list>
75
79
  #include <limits>
76
80
  #include <map>
77
81
  #include <memory>
78
82
  #include <mutex>
79
- #include <stdarg.h>
80
- #include <stdio.h>
81
- #include <stdlib.h>
83
+ #include <cstdarg>
84
+ #include <cstdio>
85
+ #include <cstdlib>
82
86
  #include <string>
83
87
  #include <vector>
84
88
 
85
89
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
86
90
 
91
+ #define GGML_LOG_WARN_ONCE(str) \
92
+ { static std::once_flag warn_flag; std::call_once(warn_flag, []() { GGML_LOG_WARN(str); }); }
93
+
87
94
  [[noreturn]]
88
95
  void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
89
96
  int id = -1; // in case cudaGetDevice fails
@@ -122,7 +129,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
122
129
  err = cudaMallocManaged(ptr, size);
123
130
  #if defined(GGML_USE_HIP)
124
131
  if (err == hipSuccess) {
125
- CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
132
+ // hipMemAdviseSetCoarseGrain is an optional performance hint;
133
+ // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
134
+ (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
135
+ (void)hipGetLastError(); // clear any error
126
136
  }
127
137
 
128
138
  // fall back to cudaMalloc if not supported (e.g. on Windows)
@@ -203,7 +213,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
203
213
  GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
204
214
 
205
215
  int64_t total_vram = 0;
206
- GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
216
+ for (int id = 0; id < info.device_count; ++id) {
217
+ cudaDeviceProp prop;
218
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
219
+ total_vram += prop.totalGlobalMem;
220
+ }
221
+ GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
222
+ __func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
223
+ total_vram = 0;
207
224
 
208
225
  std::vector<std::pair<int, std::string>> turing_devices_without_mma;
209
226
  for (int id = 0; id < info.device_count; ++id) {
@@ -241,6 +258,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
241
258
  #else
242
259
  info.devices[id].supports_cooperative_launch = false;
243
260
  #endif // !(GGML_USE_MUSA)
261
+
244
262
  #if defined(GGML_USE_HIP)
245
263
  info.devices[id].smpbo = prop.sharedMemPerBlock;
246
264
 
@@ -255,22 +273,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
255
273
  info.devices[id].cc += prop.minor * 0x10;
256
274
  }
257
275
  }
258
- GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
276
+ GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n",
259
277
  id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
260
- device_vmm ? "yes" : "no", prop.warpSize);
278
+ device_vmm ? "yes" : "no", prop.warpSize,
279
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
261
280
  #elif defined(GGML_USE_MUSA)
262
281
  // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
263
282
  info.devices[id].warp_size = 32;
264
283
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
265
284
  info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
266
285
  info.devices[id].cc += prop.minor * 0x10;
267
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
268
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
286
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
287
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
288
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
269
289
  #else
270
290
  info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
271
291
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
272
- GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
273
- id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
292
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
293
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
294
+ (size_t)(prop.totalGlobalMem / (1024 * 1024)));
274
295
  std::string device_name(prop.name);
275
296
  if (device_name == "NVIDIA GeForce MX450") {
276
297
  turing_devices_without_mma.push_back({ id, device_name });
@@ -285,6 +306,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
285
306
  // TODO: Check for future drivers the default scheduling strategy and
286
307
  // remove this call again when cudaDeviceScheduleSpin is default.
287
308
  if (prop.major == 12 && prop.minor == 1) {
309
+ CUDA_CHECK(cudaSetDevice(id));
288
310
  CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
289
311
  }
290
312
 
@@ -308,6 +330,22 @@ static ggml_cuda_device_info ggml_cuda_init() {
308
330
  // configure logging to stdout
309
331
  // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
310
332
 
333
+ if (getenv("GGML_CUDA_P2P") != nullptr) {
334
+ for (int id = 0; id < info.device_count; ++id) {
335
+ ggml_cuda_set_device(id);
336
+ for (int id_other = 0; id_other < info.device_count; ++id_other) {
337
+ if (id == id_other) {
338
+ continue;
339
+ }
340
+ int can_access_peer;
341
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
342
+ if (can_access_peer) {
343
+ CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
344
+ }
345
+ }
346
+ }
347
+ }
348
+
311
349
  return info;
312
350
  }
313
351
 
@@ -336,15 +374,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
336
374
  }
337
375
 
338
376
  ~ggml_cuda_pool_leg() {
377
+ clear_pool();
378
+ GGML_ASSERT(pool_size == 0);
379
+ }
380
+
381
+ void clear_pool() {
339
382
  ggml_cuda_set_device(device);
340
383
  for (int i = 0; i < MAX_BUFFERS; ++i) {
341
384
  ggml_cuda_buffer & b = buffer_pool[i];
342
385
  if (b.ptr != nullptr) {
343
386
  CUDA_CHECK(cudaFree(b.ptr));
344
387
  pool_size -= b.size;
388
+ b.ptr = nullptr;
389
+ b.size = 0;
345
390
  }
346
391
  }
347
- GGML_ASSERT(pool_size == 0);
348
392
  }
349
393
 
350
394
  void * alloc(size_t size, size_t * actual_size) override {
@@ -389,7 +433,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
389
433
  size_t look_ahead_size = (size_t) (1.05 * size);
390
434
  look_ahead_size = 256 * ((look_ahead_size + 255)/256);
391
435
  ggml_cuda_set_device(device);
392
- CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
436
+ cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
437
+ if (err == cudaErrorMemoryAllocation) {
438
+ (void)cudaGetLastError();
439
+ const size_t cached_bytes = pool_size;
440
+ GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n",
441
+ device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0);
442
+ CUDA_CHECK(cudaDeviceSynchronize());
443
+ clear_pool();
444
+ err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device);
445
+ if (err == cudaSuccess) {
446
+ GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device);
447
+ }
448
+ }
449
+ CUDA_CHECK(err);
393
450
  *actual_size = look_ahead_size;
394
451
  pool_size += look_ahead_size;
395
452
  #ifdef DEBUG_CUDA_MALLOC
@@ -565,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
565
622
 
566
623
  // cuda buffer
567
624
 
625
+ struct ggml_backend_cuda_device_context {
626
+ int device;
627
+ std::string name;
628
+ std::string description;
629
+ std::string pci_bus_id;
630
+ int op_offload_min_batch_size;
631
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
632
+ std::mutex device_mutex;
633
+ int active_count = 0;
634
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
635
+ };
636
+
568
637
  struct ggml_backend_cuda_buffer_context {
569
638
  int device;
570
639
  void * dev_ptr = nullptr;
@@ -582,6 +651,13 @@ struct ggml_backend_cuda_buffer_context {
582
651
 
583
652
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
584
653
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
654
+
655
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
656
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
657
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
658
+ dev_ctx->active_count--;
659
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
660
+
585
661
  delete ctx;
586
662
  }
587
663
 
@@ -616,26 +692,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
616
692
  }
617
693
 
618
694
  static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
619
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
695
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
620
696
 
621
697
  ggml_cuda_set_device(ctx->device);
622
- CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
698
+ CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread));
623
699
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
624
700
  }
625
701
 
626
702
  static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
627
- ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
703
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
628
704
 
629
705
  ggml_cuda_set_device(ctx->device);
630
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
706
+ CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
631
707
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
632
708
  }
633
709
 
634
710
  static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
711
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
712
+
713
+ ggml_cuda_set_device(ctx->device);
714
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
715
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
716
+ }
717
+
718
+ static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data,
719
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
720
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
721
+
722
+ ggml_cuda_set_device(ctx->device);
723
+ CUDA_CHECK(cudaMemcpy2DAsync(
724
+ (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread));
725
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
726
+ }
727
+
728
+ static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data,
729
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
635
730
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
636
731
 
637
732
  ggml_cuda_set_device(ctx->device);
638
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
733
+ CUDA_CHECK(cudaMemcpy2DAsync(
734
+ data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread));
639
735
  CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
640
736
  }
641
737
 
@@ -675,6 +771,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
675
771
  /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
676
772
  /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
677
773
  /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
774
+ /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d,
775
+ /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d,
678
776
  /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
679
777
  /* .clear = */ ggml_backend_cuda_buffer_clear,
680
778
  /* .reset = */ NULL,
@@ -712,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
712
810
 
713
811
  ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
714
812
 
813
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
814
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
815
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
816
+ dev_ctx->active_count++;
817
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
818
+
715
819
  return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
716
820
  }
717
821
 
@@ -722,7 +826,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty
722
826
  }
723
827
 
724
828
  static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
725
- size_t size = ggml_nbytes(tensor);
829
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context;
830
+
831
+ size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT
832
+ ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor)
833
+ : ggml_nbytes(tensor);
726
834
  int64_t ne0 = tensor->ne[0];
727
835
 
728
836
  if (ggml_is_quantized(tensor->type)) {
@@ -733,8 +841,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
733
841
  }
734
842
 
735
843
  return size;
736
-
737
- GGML_UNUSED(buft);
738
844
  }
739
845
 
740
846
  static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
@@ -987,6 +1093,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
987
1093
  /* .memset_tensor = */ NULL,
988
1094
  /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
989
1095
  /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
1096
+ /* .set_tensor_2d = */ NULL,
1097
+ /* .get_tensor_2d = */ NULL,
990
1098
  /* .cpy_tensor = */ NULL,
991
1099
  /* .clear = */ ggml_backend_cuda_split_buffer_clear,
992
1100
  /* .reset = */ NULL,
@@ -1063,6 +1171,295 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
1063
1171
  /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
1064
1172
  };
1065
1173
 
1174
+ // Communication context for multi-GPU AllReduce during tensor parallelism.
1175
+ //
1176
+ // Created once per meta backend instance. Resources for the selected mode
1177
+ // (NCCL communicators or the internal AllReduce pipeline) are initialised
1178
+ // eagerly during comm_init so any init failure surfaces at startup rather
1179
+ // than mid-run.
1180
+ struct ggml_backend_cuda_comm_context {
1181
+ using try_allreduce_fn = bool(*)(ggml_backend_cuda_comm_context *, struct ggml_tensor **);
1182
+
1183
+ std::vector<ggml_backend_t> backends;
1184
+ std::vector<int> dev_ids;
1185
+
1186
+ // Set by the init chain (comm_init_{nccl, internal, none}) to one of
1187
+ // try_allreduce_{nccl, internal, butterfly}. nccl needs `comms`,
1188
+ // internal needs `ar_pipeline`, butterfly needs nothing. Per-call
1189
+ // failures return false; the meta backend's generic implementation then
1190
+ // handles that call.
1191
+ try_allreduce_fn try_allreduce = nullptr;
1192
+
1193
+ ggml_cuda_ar_pipeline * ar_pipeline = nullptr;
1194
+
1195
+ #ifdef GGML_USE_NCCL
1196
+ std::vector<ncclComm_t> comms;
1197
+ #endif // GGML_USE_NCCL
1198
+
1199
+ ~ggml_backend_cuda_comm_context() {
1200
+ #ifdef GGML_USE_NCCL
1201
+ for (ncclComm_t comm : comms) {
1202
+ NCCL_CHECK(ncclCommDestroy(comm));
1203
+ }
1204
+ #endif // GGML_USE_NCCL
1205
+ ggml_cuda_ar_pipeline_free(ar_pipeline);
1206
+ }
1207
+ };
1208
+
1209
+ #ifdef GGML_USE_NCCL
1210
+ // AllReduce via NCCL. Reduces as FP32 for small tensors and BF16 for large
1211
+ // tensors (bandwidth-bound), then converts back to FP32.
1212
+ static bool ggml_backend_cuda_comm_allreduce_nccl(
1213
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1214
+ const int64_t ne = ggml_nelements(tensors[0]);
1215
+ // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
1216
+ // This then causes a crash in this function
1217
+ if (ne == 0) {
1218
+ return true;
1219
+ }
1220
+
1221
+ const size_t n_backends = comm_ctx->backends.size();
1222
+
1223
+ for (size_t i = 0; i < n_backends; ++i) {
1224
+ GGML_ASSERT(tensors[i] != nullptr);
1225
+ GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
1226
+ GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
1227
+ }
1228
+
1229
+ // For small tensors, simply reduce them as FP32.
1230
+ // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
1231
+ if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
1232
+ for (size_t i = 0; i < n_backends; ++i) {
1233
+ if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
1234
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
1235
+ ggml_cuda_set_device(cuda_ctx->device);
1236
+ CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream()));
1237
+ }
1238
+ }
1239
+ NCCL_CHECK(ncclGroupStart());
1240
+ for (size_t i = 0; i < n_backends; ++i) {
1241
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
1242
+ NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
1243
+ }
1244
+ NCCL_CHECK(ncclGroupEnd());
1245
+ return true;
1246
+ }
1247
+
1248
+ // For large tensors it's faster to compress them to BF16 for the reduction:
1249
+ to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
1250
+ to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
1251
+
1252
+ ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
1253
+ for (size_t i = 0; i < n_backends; ++i) {
1254
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
1255
+ tmp[i].pool = &cuda_ctx->pool();
1256
+ tmp[i].alloc(ne);
1257
+
1258
+ ggml_cuda_set_device(cuda_ctx->device);
1259
+ if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) {
1260
+ to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
1261
+ } else {
1262
+ CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream()));
1263
+ }
1264
+ CUDA_CHECK(cudaGetLastError());
1265
+ }
1266
+
1267
+ NCCL_CHECK(ncclGroupStart());
1268
+ for (size_t i = 0; i < n_backends; ++i) {
1269
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
1270
+ NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
1271
+ }
1272
+ NCCL_CHECK(ncclGroupEnd());
1273
+
1274
+ for (size_t i = 0; i < n_backends; ++i) {
1275
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
1276
+
1277
+ ggml_cuda_set_device(cuda_ctx->device);
1278
+ to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
1279
+ CUDA_CHECK(cudaGetLastError());
1280
+ }
1281
+
1282
+ return true;
1283
+ }
1284
+ #endif // GGML_USE_NCCL
1285
+
1286
+ // Run the internal AR pipeline. Returns false on unsupported / failed input
1287
+ // -- the caller decides whether to abort (env-forced) or fall back silently.
1288
+ static bool ggml_backend_cuda_comm_allreduce_internal(
1289
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1290
+ GGML_ASSERT(comm_ctx->ar_pipeline != nullptr);
1291
+
1292
+ const size_t n_backends = comm_ctx->backends.size();
1293
+ GGML_ASSERT(n_backends == 2);
1294
+ GGML_ASSERT(tensors[0] != nullptr);
1295
+
1296
+ const int64_t ne = ggml_nelements(tensors[0]);
1297
+ const ggml_type type = tensors[0]->type;
1298
+
1299
+ if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) {
1300
+ GGML_LOG_DEBUG("%s: internal unsupported: type=%d\n", __func__, (int) type);
1301
+ return false;
1302
+ }
1303
+
1304
+ if (ne == 0) {
1305
+ return true;
1306
+ }
1307
+
1308
+ for (size_t i = 0; i < n_backends; ++i) {
1309
+ if (tensors[i] == nullptr) {
1310
+ GGML_LOG_ERROR("%s: internal failed: tensor[%zu] is null\n", __func__, i);
1311
+ return false;
1312
+ }
1313
+ if (ggml_nelements(tensors[i]) != ne || tensors[i]->type != type) {
1314
+ GGML_LOG_ERROR("%s: internal failed: tensor[%zu] ne=%" PRId64 " type=%d expected ne=%" PRId64 " type=%d\n",
1315
+ __func__, i, ggml_nelements(tensors[i]), (int) tensors[i]->type, ne, (int) type);
1316
+ return false;
1317
+ }
1318
+ if (!ggml_is_contiguously_allocated(tensors[i])) {
1319
+ GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] is not contiguously allocated: ne=%" PRId64 " nbytes=%zu packed=%zu type=%d\n",
1320
+ __func__, i, ne, ggml_nbytes(tensors[i]),
1321
+ (size_t) ne * ggml_type_size(type) / ggml_blck_size(type), (int) type);
1322
+ return false;
1323
+ }
1324
+ if (((uintptr_t) tensors[i]->data & 0xF) != 0) {
1325
+ GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] data pointer is not 16-byte aligned: %p type=%d ne=%" PRId64 "\n",
1326
+ __func__, i, tensors[i]->data, (int) type, ne);
1327
+ return false;
1328
+ }
1329
+ GGML_ASSERT((ggml_nbytes(tensors[i]) & 0xF) == 0);
1330
+ }
1331
+
1332
+ return ggml_cuda_ar_allreduce(comm_ctx->ar_pipeline, comm_ctx->backends.data(), tensors);
1333
+ }
1334
+
1335
+ // ---------------------------------------------------------------------------
1336
+ // Per-call dispatch -- three variants, one per backend. Each is set as
1337
+ // comm_ctx->try_allreduce by the matching init step. Per-call failure
1338
+ // returns false; the meta backend's generic implementation handles that call.
1339
+ // ---------------------------------------------------------------------------
1340
+
1341
+ #ifdef GGML_USE_NCCL
1342
+ static bool ggml_backend_cuda_comm_try_allreduce_nccl(
1343
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1344
+ return ggml_backend_cuda_comm_allreduce_nccl(comm_ctx, tensors);
1345
+ }
1346
+ #endif // GGML_USE_NCCL
1347
+
1348
+ static bool ggml_backend_cuda_comm_try_allreduce_internal(
1349
+ ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) {
1350
+ return ggml_backend_cuda_comm_allreduce_internal(comm_ctx, tensors);
1351
+ }
1352
+
1353
+ static bool ggml_backend_cuda_comm_try_allreduce_butterfly(
1354
+ ggml_backend_cuda_comm_context *, struct ggml_tensor **) {
1355
+ return false;
1356
+ }
1357
+
1358
+ static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
1359
+ if (comm_ctx_v == nullptr) {
1360
+ return;
1361
+ }
1362
+ delete static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v);
1363
+ }
1364
+
1365
+ // ---------------------------------------------------------------------------
1366
+ // Init -- chained nccl -> internal -> none. Each step tries to bring up its
1367
+ // resource; on failure it warns and recurses into the next step.
1368
+ // ---------------------------------------------------------------------------
1369
+ static void ggml_backend_cuda_comm_init_none(ggml_backend_cuda_comm_context * ret) {
1370
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_butterfly;
1371
+ }
1372
+
1373
+ static void ggml_backend_cuda_comm_init_internal(ggml_backend_cuda_comm_context * ret) {
1374
+ ret->ar_pipeline = ggml_cuda_ar_pipeline_init(ret->dev_ids.data(), ret->dev_ids.size());
1375
+ if (ret->ar_pipeline) {
1376
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_internal;
1377
+ return;
1378
+ }
1379
+
1380
+ // Clear sticky CUDA error from the failed init.
1381
+ (void) cudaGetLastError();
1382
+ GGML_LOG_WARN("internal AllReduce init failed (n_devices != 2?); "
1383
+ "falling back to meta-backend butterfly\n");
1384
+ ggml_backend_cuda_comm_init_none(ret);
1385
+ }
1386
+
1387
+ static void ggml_backend_cuda_comm_init_nccl(ggml_backend_cuda_comm_context * ret) {
1388
+ #ifdef GGML_USE_NCCL
1389
+ const size_t n = ret->dev_ids.size();
1390
+ ret->comms.resize(n);
1391
+ ncclResult_t rc = ncclCommInitAll(ret->comms.data(), (int) n, ret->dev_ids.data());
1392
+ if (rc == ncclSuccess) {
1393
+ ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_nccl;
1394
+ return;
1395
+ }
1396
+
1397
+ ret->comms.clear();
1398
+ GGML_LOG_WARN("NCCL init failed (%s); falling back to internal AllReduce\n",
1399
+ ncclGetErrorString(rc));
1400
+ #else // GGML_USE_NCCL
1401
+ #ifndef GGML_USE_HIP
1402
+ GGML_LOG_WARN("NCCL not compiled in; falling back to internal AllReduce. "
1403
+ "Recompile with -DGGML_CUDA_NCCL=ON for best multi-GPU performance.\n");
1404
+ #endif // !GGML_USE_HIP
1405
+ #endif // GGML_USE_NCCL
1406
+
1407
+ ggml_backend_cuda_comm_init_internal(ret);
1408
+ }
1409
+
1410
+ // Top-level init. Picks one of the three init paths based on
1411
+ // GGML_CUDA_ALLREDUCE (or the platform default) and lets the chain handle
1412
+ // any fallback. Unrecognised env values warn and fall through to the
1413
+ // platform default.
1414
+ static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
1415
+ for (size_t i = 0; i < n_backends; i++) {
1416
+ if (!ggml_backend_is_cuda(backends[i])) {
1417
+ return nullptr;
1418
+ }
1419
+ }
1420
+
1421
+ auto * ret = new ggml_backend_cuda_comm_context;
1422
+ ret->backends.assign(backends, backends + n_backends);
1423
+ ret->dev_ids.reserve(n_backends);
1424
+ for (size_t i = 0; i < n_backends; i++) {
1425
+ ret->dev_ids.push_back(static_cast<ggml_backend_cuda_context *>(backends[i]->context)->device);
1426
+ }
1427
+
1428
+ const char * env = getenv("GGML_CUDA_ALLREDUCE");
1429
+ if (!env) {
1430
+ // Platform default: Linux uses NCCL, otherwise (generally Windows) internal
1431
+ #if defined(__linux__)
1432
+ ggml_backend_cuda_comm_init_nccl(ret);
1433
+ #else
1434
+ ggml_backend_cuda_comm_init_internal(ret);
1435
+ #endif // defined(__linux__)
1436
+ } else {
1437
+ std::string env_str(env);
1438
+ if (env_str == "nccl") {
1439
+ ggml_backend_cuda_comm_init_nccl(ret);
1440
+ } else if (env_str == "internal") {
1441
+ ggml_backend_cuda_comm_init_internal(ret);
1442
+ } else if (env_str == "none") {
1443
+ ggml_backend_cuda_comm_init_none(ret);
1444
+ } else {
1445
+ GGML_LOG_WARN("unknown GGML_CUDA_ALLREDUCE value: %s\n", env);
1446
+ ggml_backend_cuda_comm_init_none(ret);
1447
+ }
1448
+ }
1449
+
1450
+ return ret;
1451
+ }
1452
+
1453
+ // Top-level dispatch -- calls the function pointer chosen by comm_init.
1454
+ // Returns false to let the meta-backend's butterfly run.
1455
+ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
1456
+ if (comm_ctx_v == nullptr) {
1457
+ return false;
1458
+ }
1459
+ auto * comm_ctx = static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v);
1460
+ return comm_ctx->try_allreduce(comm_ctx, tensors);
1461
+ }
1462
+
1066
1463
  ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
1067
1464
  static std::mutex mutex;
1068
1465
  std::lock_guard<std::mutex> lock(mutex);
@@ -1118,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
1118
1515
  }
1119
1516
 
1120
1517
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1518
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1519
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
1520
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
1521
+ dev_ctx->active_count--;
1522
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1523
+
1121
1524
  CUDA_CHECK(cudaFreeHost(buffer->context));
1122
1525
  }
1123
1526
 
@@ -1126,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
1126
1529
  return nullptr;
1127
1530
  }
1128
1531
 
1532
+ ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.
1533
+
1129
1534
  void * ptr = nullptr;
1130
1535
  cudaError_t err = cudaMallocHost((void **) &ptr, size);
1131
1536
  if (err != cudaSuccess) {
@@ -1151,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
1151
1556
  buffer->buft = buft;
1152
1557
  buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
1153
1558
 
1559
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1560
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
1561
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
1562
+ dev_ctx->active_count++;
1563
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1564
+
1154
1565
  return buffer;
1155
1566
  }
1156
1567
 
@@ -1224,6 +1635,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
1224
1635
  }
1225
1636
  }
1226
1637
 
1638
+ struct cublas_force_compute_type {
1639
+ bool fp32 = false;
1640
+ bool fp16 = false;
1641
+ };
1642
+
1643
+ static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
1644
+ static const cublas_force_compute_type compute_type = [] {
1645
+ cublas_force_compute_type result;
1646
+
1647
+ const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
1648
+ const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
1649
+
1650
+ GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);
1651
+
1652
+ if (ggml_cuda_force_cublas_compute_32f_env) {
1653
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
1654
+ result.fp32 = true;
1655
+ } else if (ggml_cuda_force_cublas_compute_16f_env) {
1656
+ GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
1657
+ result.fp16 = true;
1658
+ }
1659
+
1660
+ return result;
1661
+ }();
1662
+
1663
+ return compute_type;
1664
+ }
1665
+
1227
1666
  static void ggml_cuda_op_mul_mat_cublas(
1228
1667
  ggml_backend_cuda_context & ctx,
1229
1668
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -1252,7 +1691,12 @@ static void ggml_cuda_op_mul_mat_cublas(
1252
1691
  const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1253
1692
  (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
1254
1693
 
1255
- const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
1694
+ const bool use_fp16 =
1695
+ src0->type != GGML_TYPE_NVFP4 &&
1696
+ (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
1697
+ ggml_is_contiguous(src0) &&
1698
+ row_diff == src0->ne[1] &&
1699
+ dst->op_params[0] == GGML_PREC_DEFAULT;
1256
1700
 
1257
1701
  if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1258
1702
  ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -1306,7 +1750,13 @@ static void ggml_cuda_op_mul_mat_cublas(
1306
1750
 
1307
1751
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1308
1752
 
1309
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1753
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
1754
+
1755
+ if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
1756
+ || GGML_CUDA_CC_IS_RDNA4(cc)
1757
+ || cc == GGML_CUDA_CC_VOLTA
1758
+ || force_compute_type.fp32))
1759
+ {
1310
1760
  const float alpha = 1.0f;
1311
1761
  const float beta = 0.0f;
1312
1762
  CUBLAS_CHECK(
@@ -1370,64 +1820,6 @@ static void ggml_cuda_op_mul_mat_cublas(
1370
1820
  GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
1371
1821
  }
1372
1822
 
1373
- static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
1374
- static bool peer_access_enabled = false;
1375
-
1376
- const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
1377
-
1378
- if (peer_access_enabled == enable_peer_access) {
1379
- return;
1380
- }
1381
-
1382
- #ifdef NDEBUG
1383
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
1384
- ggml_cuda_set_device(id);
1385
- CUDA_CHECK(cudaDeviceSynchronize());
1386
- }
1387
-
1388
- for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
1389
- ggml_cuda_set_device(id);
1390
-
1391
- for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
1392
- if (id == id_other) {
1393
- continue;
1394
- }
1395
- if (id != main_device && id_other != main_device) {
1396
- continue;
1397
- }
1398
-
1399
- int can_access_peer;
1400
- CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
1401
- if (can_access_peer) {
1402
- if (enable_peer_access) {
1403
- cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
1404
- if (err != cudaErrorPeerAccessAlreadyEnabled) {
1405
- CUDA_CHECK(err);
1406
- } else {
1407
- // reset the error
1408
- (void)cudaGetLastError();
1409
- }
1410
- } else {
1411
- cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
1412
- if (err != cudaErrorPeerAccessNotEnabled) {
1413
- CUDA_CHECK(err);
1414
- } else {
1415
- // reset the error
1416
- (void)cudaGetLastError();
1417
- }
1418
- }
1419
- }
1420
- }
1421
- }
1422
-
1423
- ggml_cuda_set_device(main_device);
1424
- #endif // NDEBUG
1425
-
1426
- peer_access_enabled = enable_peer_access;
1427
-
1428
- GGML_UNUSED(main_device);
1429
- }
1430
-
1431
1823
  static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
1432
1824
  void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
1433
1825
 
@@ -1905,10 +2297,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1905
2297
  cudaDataType_t cu_data_type_b = traits::data_type;
1906
2298
  const void * alpha = traits::get_alpha();
1907
2299
  const void * beta = traits::get_beta();
1908
- const float alpha_f32 = 1.0f;
1909
- const float beta_f32 = 0.0f;
1910
2300
 
1911
- if (dst->op_params[0] == GGML_PREC_DEFAULT) {
2301
+ const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
2302
+
2303
+ int id = ggml_cuda_get_device();
2304
+ const int cc = ggml_cuda_info().devices[id].cc;
2305
+ static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
2306
+
2307
+ // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),
2308
+ // so checking necessity of forced fp32 only for fp16 src0_type
2309
+ static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
2310
+
2311
+ const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
2312
+ || GGML_CUDA_CC_IS_RDNA4(cc)
2313
+ || cc == GGML_CUDA_CC_VOLTA
2314
+ || force_compute_type.fp32);
2315
+
2316
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
1912
2317
  if constexpr (src0_type == GGML_TYPE_F32) {
1913
2318
  dst_t = (char *) dst_ddf; // Direct F32 output
1914
2319
  } else {
@@ -1918,18 +2323,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1918
2323
  }
1919
2324
  } else {
1920
2325
  dst_t = (char *) dst_ddf;
1921
- cu_compute_type = CUBLAS_COMPUTE_32F;
1922
- cu_data_type = CUDA_R_32F;
1923
- alpha = &alpha_f32;
1924
- beta = &beta_f32;
1925
- }
1926
-
1927
- int id = ggml_cuda_get_device();
1928
- const int cc = ggml_cuda_info().devices[id].cc;
1929
- if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1930
- cu_compute_type = CUBLAS_COMPUTE_32F;
1931
- alpha = &alpha_f32;
1932
- beta = &beta_f32;
2326
+ cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
2327
+ cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
2328
+ alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
2329
+ beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
1933
2330
  }
1934
2331
 
1935
2332
  GGML_ASSERT(ne12 % ne02 == 0);
@@ -2214,6 +2611,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2214
2611
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2215
2612
  use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2216
2613
  use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2614
+ use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
2217
2615
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2218
2616
  }
2219
2617
  } else {
@@ -2222,6 +2620,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2222
2620
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
2223
2621
  use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
2224
2622
  use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
2623
+ use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
2225
2624
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
2226
2625
  }
2227
2626
 
@@ -2239,6 +2638,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
2239
2638
  bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
2240
2639
  bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
2241
2640
 
2641
+ const int32_t hint = ggml_get_op_params_i32(dst, 1);
2642
+ if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) {
2643
+ return;
2644
+ }
2645
+
2242
2646
  if (!split && use_mul_mat_vec_f) {
2243
2647
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
2244
2648
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -2277,14 +2681,22 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2277
2681
 
2278
2682
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2279
2683
 
2684
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2280
2685
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2281
- if (ne2 == 1) {
2686
+ static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
2687
+ if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
2282
2688
  if (ggml_is_quantized(src0->type)) {
2283
- ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2689
+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
2690
+ if (ne2 <= mmvq_mmid_max) {
2691
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2692
+ return;
2693
+ }
2284
2694
  } else {
2285
- ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2695
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
2696
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
2697
+ return;
2698
+ }
2286
2699
  }
2287
- return;
2288
2700
  }
2289
2701
 
2290
2702
  if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
@@ -2298,6 +2710,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2298
2710
  }
2299
2711
  }
2300
2712
 
2713
+ // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
2714
+ // TODO: add asserts to verify this. should work with CUDA, HIP, etc.
2301
2715
  cudaStream_t stream = ctx.stream();
2302
2716
 
2303
2717
  GGML_ASSERT(nb12 % nb11 == 0);
@@ -2413,11 +2827,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2413
2827
  }
2414
2828
 
2415
2829
  static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
2416
- // why is this here instead of mul_mat?
2417
- if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
2418
- ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
2419
- }
2420
-
2421
2830
  switch (dst->op) {
2422
2831
  case GGML_OP_ARGMAX:
2423
2832
  ggml_cuda_argmax(ctx, dst);
@@ -2723,6 +3132,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2723
3132
  case GGML_OP_GATED_LINEAR_ATTN:
2724
3133
  ggml_cuda_op_gated_linear_attn(ctx, dst);
2725
3134
  break;
3135
+ case GGML_OP_GATED_DELTA_NET:
3136
+ ggml_cuda_op_gated_delta_net(ctx, dst);
3137
+ break;
2726
3138
  case GGML_OP_RWKV_WKV7:
2727
3139
  ggml_cuda_op_rwkv_wkv7(ctx, dst);
2728
3140
  break;
@@ -2767,26 +3179,54 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
2767
3179
  static void ggml_backend_cuda_free(ggml_backend_t backend) {
2768
3180
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
2769
3181
 
3182
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
3183
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
3184
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
3185
+ dev_ctx->active_count--;
3186
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
3187
+
2770
3188
  delete cuda_ctx;
2771
3189
  delete backend;
2772
3190
  }
2773
3191
 
2774
3192
  static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2775
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3193
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
2776
3194
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2777
3195
 
2778
3196
  GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
2779
3197
 
2780
- CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
3198
+ CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
2781
3199
  }
2782
3200
 
2783
3201
  static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2784
- ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3202
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3203
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
3204
+
3205
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
3206
+
3207
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
3208
+ }
3209
+
3210
+ static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data,
3211
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
3212
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3213
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
3214
+
3215
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
3216
+
3217
+ CUDA_CHECK(cudaMemcpy2DAsync(
3218
+ (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream()));
3219
+ }
3220
+
3221
+ static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data,
3222
+ size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
3223
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
2785
3224
  ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2786
3225
 
2787
3226
  GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
2788
3227
 
2789
- CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
3228
+ CUDA_CHECK(cudaMemcpy2DAsync(
3229
+ data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
2790
3230
  }
2791
3231
 
2792
3232
  static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
@@ -2797,21 +3237,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
2797
3237
  return false;
2798
3238
  }
2799
3239
 
2800
- if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
3240
+ if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
2801
3241
  return false;
2802
3242
  }
2803
3243
 
2804
3244
  // device -> device copy
2805
- ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
2806
- ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
3245
+ ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context;
3246
+ ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context;
2807
3247
 
2808
- ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
2809
- ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
3248
+ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
3249
+ ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
2810
3250
 
2811
3251
  if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
2812
3252
  #ifndef NDEBUG
2813
3253
  GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
2814
- #endif
3254
+ #endif // NDEBUG
2815
3255
  return false;
2816
3256
  }
2817
3257
 
@@ -2824,7 +3264,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
2824
3264
  return false;
2825
3265
  #else
2826
3266
  CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
2827
- #endif
3267
+ #endif // GGML_CUDA_NO_PEER_COPY
2828
3268
  }
2829
3269
 
2830
3270
  // record event on src stream after the copy
@@ -2858,14 +3298,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2858
3298
  bool use_cuda_graph = true;
2859
3299
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2860
3300
 
2861
- const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2862
- const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2863
- const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2864
- const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2865
- const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2866
- const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2867
- const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2868
-
2869
3301
  for (int i = 0; i < cgraph->n_nodes; i++) {
2870
3302
  ggml_tensor * node = cgraph->nodes[i];
2871
3303
 
@@ -2880,31 +3312,19 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2880
3312
  #endif
2881
3313
  }
2882
3314
 
2883
- if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2884
- use_cuda_graph = false; // This node type is not supported by CUDA graph capture
3315
+ // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
3316
+ if (node->op == GGML_OP_MUL_MAT_ID) {
3317
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
3318
+ const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
3319
+ if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
3320
+ // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
3321
+ // TODO: figure out a way to enable for larger batch sizes, without hurting performance
3322
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18958
3323
+ use_cuda_graph = false;
2885
3324
  #ifndef NDEBUG
2886
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2887
- #endif
2888
- }
2889
-
2890
- if (node->op == GGML_OP_ADD &&
2891
- node->src[1] && node->src[1]->ne[1] > 1 &&
2892
- (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2893
- (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2894
- strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2895
- strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2896
- strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2897
- strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2898
- strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
2899
- // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2900
- // by means of matching node names. See
2901
- // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2902
- // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2903
- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2904
- use_cuda_graph = false;
2905
- #ifndef NDEBUG
2906
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
3325
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2907
3326
  #endif
3327
+ }
2908
3328
  }
2909
3329
 
2910
3330
  if (!use_cuda_graph) {
@@ -2915,105 +3335,62 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2915
3335
  return use_cuda_graph;
2916
3336
  }
2917
3337
 
2918
- static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919
- props->node_address = node->data;
2920
- props->node_op = node->op;
2921
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2922
- props->ne[i] = node->ne[i];
2923
- props->nb[i] = node->nb[i];
2924
- }
2925
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2926
- props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2927
- }
2928
- memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
3338
+ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
3339
+ return cgraph->nodes[0];
2929
3340
  }
2930
3341
 
2931
- static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2932
- if (node->data != props->node_address &&
2933
- node->op != GGML_OP_VIEW) {
2934
- return false;
2935
- }
2936
-
2937
- if (node->op != props->node_op) {
2938
- return false;
2939
- }
3342
+ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
3343
+ bool res = false;
2940
3344
 
2941
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2942
- if (node->ne[i] != props->ne[i]) {
2943
- return false;
2944
- }
2945
- if (node->nb[i] != props->nb[i]) {
2946
- return false;
2947
- }
2948
- }
3345
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3346
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
2949
3347
 
2950
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2951
- if (node->src[i] &&
2952
- node->src[i]->data != props->src_address[i] &&
2953
- node->op != GGML_OP_VIEW
2954
- ) {
2955
- return false;
2956
- }
2957
- }
2958
-
2959
- if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960
- memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
3348
+ if (cgraph->uid != 0 &&
3349
+ cgraph->uid == graph->uid) {
3350
+ GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid);
3351
+ GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes);
2961
3352
  return false;
2962
3353
  }
2963
3354
 
2964
- return true;
2965
- }
2966
-
2967
- static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2968
-
2969
- bool res = false;
2970
-
2971
- if (cuda_ctx->cuda_graph->instance == nullptr) {
2972
- res = true;
2973
- }
3355
+ graph->uid = cgraph->uid;
2974
3356
 
2975
3357
  // Check if the graph size has changed
2976
- if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
3358
+ if ((int)graph->node_props.size() != cgraph->n_nodes) {
2977
3359
  res = true;
2978
- cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
3360
+ graph->node_props.resize(cgraph->n_nodes);
2979
3361
  }
2980
3362
 
2981
- // Loop over nodes in GGML graph to determine if CUDA graph update is required
2982
- // and store properties to allow this comparison for the next token
2983
3363
  for (int i = 0; i < cgraph->n_nodes; i++) {
2984
- bool props_match = true;
2985
- if (!res) {
2986
- props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
2987
- }
2988
- if (!props_match) {
2989
- res = true;
3364
+ ggml_cuda_graph::node_properties prop = {};
3365
+ memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
3366
+
3367
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
3368
+ if (cgraph->nodes[i]->src[j]) {
3369
+ prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
3370
+ memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
3371
+ memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
3372
+ }
2990
3373
  }
2991
- ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
2992
- }
2993
3374
 
2994
- for (int i = 0; i < cgraph->n_leafs; i++) {
2995
- bool props_match= true;
2996
- if (!res) {
2997
- props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
2998
- }
2999
- if (!props_match) {
3375
+ if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3376
+ graph->node_props[i] = prop;
3000
3377
  res = true;
3001
3378
  }
3002
- ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
3003
3379
  }
3004
3380
 
3005
3381
  return res;
3006
3382
  }
3007
3383
 
3008
- static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
3384
+ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3385
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3009
3386
 
3010
3387
  #if CUDART_VERSION >= 12000
3011
3388
  cudaGraphExecUpdateResultInfo result_info;
3012
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
3389
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
3013
3390
  #else
3014
3391
  cudaGraphNode_t errorNode;
3015
3392
  cudaGraphExecUpdateResult result_info;
3016
- cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
3393
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
3017
3394
  #endif // CUDART_VERSION >= 12000
3018
3395
 
3019
3396
  if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3024,14 +3401,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
3024
3401
  // The pre-existing graph exec cannot be updated due to violated constraints
3025
3402
  // so instead clear error and re-instantiate
3026
3403
  (void)cudaGetLastError();
3027
- CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
3028
- cuda_ctx->cuda_graph->instance = nullptr;
3029
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
3404
+ CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
3405
+ graph->instance = nullptr;
3406
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3030
3407
  } else {
3031
3408
  GGML_ASSERT(stat == cudaSuccess);
3032
3409
  }
3033
3410
  }
3034
- #endif
3411
+ #endif // USE_CUDA_GRAPH
3035
3412
 
3036
3413
  static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3037
3414
  const ggml_tensor * view,
@@ -3067,63 +3444,231 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3067
3444
  return true;
3068
3445
  }
3069
3446
 
3070
- static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
3071
- #ifndef NDEBUG
3072
- const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3073
- GGML_ASSERT(unary_ops.size() == num_unary);
3074
- #endif
3447
+ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
3448
+ args.sigmoid = false;
3449
+ args.softmax = false;
3450
+ args.delayed_softmax = false;
3451
+ args.prob_bias = false;
3452
+ args.norm = false;
3075
3453
 
3076
- //TODO: remove special case once ggml_can_fuse can handle empty nodes
3077
- std::initializer_list<enum ggml_op> topk_moe_ops =
3078
- ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
3079
- std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
3080
- ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
3081
- std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
3082
- ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
3454
+ const int n_nodes = cgraph->n_nodes;
3455
+ ggml_tensor ** nodes = cgraph->nodes;
3083
3456
 
3084
- const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3085
- const std::initializer_list<enum ggml_op> & list2) {
3086
- return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3087
- };
3457
+ if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
3458
+ args.softmax = true;
3459
+ }
3088
3460
 
3089
- if (is_equal(topk_moe_ops_with_norm, ops) &&
3090
- ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
3091
- ggml_tensor * softmax = cgraph->nodes[node_idx];
3092
- ggml_tensor * weights = cgraph->nodes[node_idx + 9];
3093
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3094
- ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3095
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3461
+ if (nodes[node_idx]->op == GGML_OP_UNARY) {
3462
+ if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
3463
+ return false;
3464
+ }
3465
+ args.sigmoid = true;
3466
+ }
3096
3467
 
3097
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3098
- return true;
3468
+ if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
3469
+ args.delayed_softmax = true;
3470
+ }
3471
+
3472
+ node_idx++;
3473
+
3474
+ if (args.sigmoid || args.softmax) {
3475
+ // SOFTMAX -> RESHAPE
3476
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
3477
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3478
+ return false;
3479
+ }
3480
+ ggml_tensor * probs_reshaped = nodes[node_idx];
3481
+ node_idx++;
3482
+
3483
+ if (node_idx >= n_nodes) {
3484
+ return false;
3485
+ }
3486
+
3487
+ // src of bias add is the unreshaped probs (-2 instead of -1)
3488
+ if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
3489
+ args.prob_bias = true;
3490
+ node_idx++;
3491
+ }
3492
+ // RESHAPE/ADD -> ARGSORT
3493
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
3494
+ return false;
3495
+ }
3496
+
3497
+ if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3498
+ return false;
3499
+ } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
3500
+ return false;
3501
+ }
3502
+
3503
+ node_idx++;
3504
+
3505
+ // ARGSORT-> VIEW
3506
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3507
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3508
+ return false;
3509
+ }
3510
+ node_idx++;
3511
+
3512
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
3513
+ return false;
3514
+ }
3515
+
3516
+ // GET_ROWS
3517
+ if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
3518
+ return false;
3519
+ }
3520
+ node_idx++;
3521
+ } else if (args.delayed_softmax) {
3522
+ if (node_idx - 2 < 0) {
3523
+ return false;
3524
+ }
3525
+ ggml_tensor * probs_reshaped = nodes[node_idx - 2];
3526
+
3527
+ // VIEW->ARGSORT
3528
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3529
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3530
+ return false;
3531
+ }
3532
+ node_idx++;
3533
+
3534
+ // GET_ROWS
3535
+ if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3536
+ nodes[node_idx]->src[0] != probs_reshaped) {
3537
+ return false;
3538
+ }
3539
+ node_idx++;
3540
+
3541
+ static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
3542
+
3543
+ for (const ggml_op op : remaining_ops) {
3544
+ if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3545
+ return false;
3546
+ }
3547
+ node_idx++;
3099
3548
  }
3100
3549
  }
3101
3550
 
3102
- if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
3103
- ggml_tensor * softmax = cgraph->nodes[node_idx];
3104
- ggml_tensor * weights = cgraph->nodes[node_idx + 4];
3105
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
3106
- ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
3107
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3551
+ // At this point we can check for norm + scale. Everything is now at least valid till the norm
3552
+ if (node_idx >= n_nodes) {
3553
+ return true;
3554
+ }
3555
+
3556
+ if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
3557
+ //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
3558
+ static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
3559
+
3560
+ args.norm = true;
3561
+ for (const ggml_op op : norm_ops) {
3562
+ if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3563
+ node_idx++;
3564
+ } else {
3565
+ args.norm = false;
3566
+ return true;
3567
+ }
3568
+ }
3569
+
3570
+ // DIV <- CLAMP, RESHAPE
3571
+ if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3572
+ nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
3573
+ args.norm = false;
3574
+ return true;
3575
+ }
3576
+ node_idx++;
3108
3577
 
3109
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3578
+ if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3579
+ args.norm = false;
3110
3580
  return true;
3111
3581
  }
3582
+
3583
+ node_idx++;
3112
3584
  }
3113
3585
 
3114
- if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
3115
- ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
3116
- ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
3117
- ggml_tensor * weights = cgraph->nodes[node_idx + 5];
3118
- ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
3119
- ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
3120
- int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
3586
+ if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3587
+ args.scale = true;
3588
+ }
3121
3589
 
3122
- if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
3590
+ return true;
3591
+ }
3592
+
3593
+ // returns whether the write (out) nodes overwrite the read nodes in operation
3594
+ static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
3595
+ const int node_idx,
3596
+ const int node_count,
3597
+ const int * out_nodes,
3598
+ const int out_count,
3599
+ const bool is_topk_moe = false) {
3600
+ auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3601
+ const int64_t a_start = (int64_t) a->data;
3602
+ const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a);
3603
+
3604
+ const int64_t b_start = (int64_t) b->data;
3605
+ const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b);
3606
+
3607
+ if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3123
3608
  return true;
3124
3609
  }
3610
+
3611
+ return false;
3612
+ };
3613
+
3614
+ bool is_ok = true;
3615
+ // exception for topk-moe, as each row is read entirely before writing
3616
+ if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
3617
+ return true;
3618
+ }
3619
+
3620
+ for (int i = 0; i < out_count; ++i) {
3621
+ const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
3622
+
3623
+ for (int j = node_idx; j < node_idx + node_count; ++j) {
3624
+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
3625
+ // the destination and the src is not an intermediate node that's being
3626
+ // elided, then disable fusion.
3627
+
3628
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3629
+ const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
3630
+
3631
+ if (!src || src->op == GGML_OP_NONE) {
3632
+ continue;
3633
+ }
3634
+
3635
+ if (nodes_overlap(dst, src)) {
3636
+ bool found = false;
3637
+
3638
+ for (int k = node_idx; k < j; ++k) {
3639
+ if (cgraph->nodes[k] == src) {
3640
+ found = true;
3641
+ break;
3642
+ }
3643
+ }
3644
+
3645
+ if (!found) {
3646
+ is_ok = false;
3647
+ break;
3648
+ }
3649
+ }
3650
+ }
3651
+ }
3125
3652
  }
3126
3653
 
3654
+ return is_ok;
3655
+ }
3656
+
3657
+
3658
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
3659
+ int node_idx,
3660
+ std::initializer_list<enum ggml_op> ops,
3661
+ std::initializer_list<enum ggml_unary_op> unary_ops) {
3662
+ #ifndef NDEBUG
3663
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
3664
+ GGML_ASSERT(unary_ops.size() == num_unary);
3665
+ #endif
3666
+
3667
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
3668
+ const std::initializer_list<enum ggml_op> & list2) {
3669
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
3670
+ };
3671
+
3127
3672
  std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
3128
3673
  std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
3129
3674
 
@@ -3139,7 +3684,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3139
3684
  const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
3140
3685
 
3141
3686
  if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3142
- return true;
3687
+ int out_nodes[] = { node_idx + 4 };
3688
+ return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
3143
3689
  }
3144
3690
  }
3145
3691
 
@@ -3150,7 +3696,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3150
3696
  const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
3151
3697
 
3152
3698
  if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3153
- return true;
3699
+ int out_nodes[] = { node_idx + 2 };
3700
+ return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
3154
3701
  }
3155
3702
  }
3156
3703
 
@@ -3200,7 +3747,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3200
3747
  return false;
3201
3748
  }
3202
3749
 
3203
- //rms_norm kernel assumes contigous rows
3750
+ //rms_norm kernel assumes contiguous rows
3204
3751
  if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
3205
3752
  return false;
3206
3753
  }
@@ -3212,6 +3759,98 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3212
3759
  return true;
3213
3760
  }
3214
3761
 
3762
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
3763
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3764
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3765
+ const ggml_tensor * silu = cgraph->nodes[node_idx+1];
3766
+ if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3767
+ return false;
3768
+ }
3769
+
3770
+ if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3771
+ return false;
3772
+ }
3773
+
3774
+ return true;
3775
+ }
3776
+
3777
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
3778
+ && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3779
+ const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3780
+ const ggml_tensor * add = cgraph->nodes[node_idx+1];
3781
+ const ggml_tensor * silu = cgraph->nodes[node_idx+2];
3782
+ if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3783
+ return false;
3784
+ }
3785
+
3786
+ if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3787
+ return false;
3788
+ }
3789
+
3790
+ // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
3791
+ const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
3792
+ if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
3793
+ return false;
3794
+ }
3795
+ if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
3796
+ return false;
3797
+ }
3798
+
3799
+ return true;
3800
+ }
3801
+
3802
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
3803
+ && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
3804
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
3805
+ const ggml_tensor * mul = cgraph->nodes[node_idx+1];
3806
+
3807
+ if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
3808
+ return false;
3809
+ }
3810
+
3811
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
3812
+ return false;
3813
+ }
3814
+
3815
+ if (unary->type != mul->type) {
3816
+ return false;
3817
+ }
3818
+
3819
+ const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
3820
+ if (other->type != unary->type) {
3821
+ return false;
3822
+ }
3823
+ if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
3824
+ return false;
3825
+ }
3826
+
3827
+ return true;
3828
+ }
3829
+
3830
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
3831
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
3832
+ const ggml_tensor * unary = cgraph->nodes[node_idx];
3833
+ const ggml_tensor * sqr = cgraph->nodes[node_idx+1];
3834
+
3835
+ if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
3836
+ return false;
3837
+ }
3838
+
3839
+ if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
3840
+ return false;
3841
+ }
3842
+
3843
+ if (unary->type != sqr->type) {
3844
+ return false;
3845
+ }
3846
+
3847
+ if (!ggml_is_contiguous(unary->src[0])) {
3848
+ return false;
3849
+ }
3850
+
3851
+ return true;
3852
+ }
3853
+
3215
3854
  if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
3216
3855
  && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
3217
3856
  const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -3236,7 +3875,407 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
3236
3875
  return false;
3237
3876
  }
3238
3877
 
3239
- static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3878
+ // try and fuse nodes and return the number of nodes to skip
3879
+ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) {
3880
+
3881
+ static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION"));
3882
+ if (disable_fusion) {
3883
+ return 0;
3884
+ }
3885
+
3886
+ ggml_tensor * node = cgraph->nodes[i];
3887
+
3888
+ //topk-moe
3889
+ if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3890
+ cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3891
+ ggml_cuda_topk_moe_args args;
3892
+ const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3893
+ std::vector<ggml_op> ops;
3894
+
3895
+ if (can_fuse) {
3896
+ const ggml_tensor * logits = node->src[0];
3897
+ ggml_tensor * weights = nullptr;
3898
+ ggml_tensor * ids = nullptr;
3899
+ const ggml_tensor * bias = nullptr;
3900
+ const ggml_tensor * clamp = nullptr;
3901
+ const ggml_tensor * scale = nullptr;
3902
+
3903
+ if (!args.delayed_softmax) {
3904
+ ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3905
+ int out_nodes[2]; // nodes which can't be elided
3906
+
3907
+ if (args.prob_bias) {
3908
+ bias = cgraph->nodes[i + 2]->src[1];
3909
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW,
3910
+ GGML_OP_GET_ROWS });
3911
+ out_nodes[0] = i + 4;
3912
+ ids = cgraph->nodes[i + 4];
3913
+ } else {
3914
+ ops.insert(ops.end(),
3915
+ { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS });
3916
+ out_nodes[0] = i + 3;
3917
+ ids = cgraph->nodes[i + 3];
3918
+ }
3919
+
3920
+ if (args.norm) {
3921
+ ops.insert(ops.end(),
3922
+ { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE });
3923
+ clamp = cgraph->nodes[i + ops.size() - 3];
3924
+ }
3925
+ if (args.scale) {
3926
+ ops.insert(ops.end(), { GGML_OP_SCALE });
3927
+ scale = cgraph->nodes[i + ops.size() - 1];
3928
+ }
3929
+
3930
+ weights = cgraph->nodes[i + ops.size() - 1];
3931
+ out_nodes[1] = i + ops.size() - 1;
3932
+
3933
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3934
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3935
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
3936
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3937
+ return ops.size() - 1;
3938
+ }
3939
+ } else if (!args.norm && !args.prob_bias) {
3940
+ //special case gpt-oss, no norm, no bias.
3941
+ ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
3942
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3943
+ weights = cgraph->nodes[i + 5];
3944
+ ids = cgraph->nodes[i + 1];
3945
+ const ggml_tensor * softmax = cgraph->nodes[i + 4];
3946
+
3947
+ int out_nodes[2] = { i + 1, i + 5 };
3948
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3949
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3950
+ ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) {
3951
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3952
+ return ops.size() - 1;
3953
+ }
3954
+ }
3955
+ }
3956
+ }
3957
+
3958
+ //RoPE + view + set-rows
3959
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3960
+ ggml_tensor * rope = cgraph->nodes[i];
3961
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
3962
+
3963
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3964
+ return 2;
3965
+ }
3966
+
3967
+ // Snake activation: y = x + sin(a*x)^2 * inv_b
3968
+ // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
3969
+ if (ggml_can_fuse_subgraph(cgraph, i,
3970
+ { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
3971
+ { i + 4 })) {
3972
+ const ggml_tensor * mul0 = cgraph->nodes[i];
3973
+ const ggml_tensor * sqr = cgraph->nodes[i + 2];
3974
+ const ggml_tensor * mul1 = cgraph->nodes[i + 3];
3975
+ ggml_tensor * add = cgraph->nodes[i + 4];
3976
+
3977
+ // x carries the full activation shape, a is the broadcast operand
3978
+ const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
3979
+ const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
3980
+
3981
+ // mul1 reads sqr and inv_b in either operand order
3982
+ const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
3983
+
3984
+ // closure check: the trailing add must read the same x as the leading mul
3985
+ const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
3986
+
3987
+ // Kernel iterates over total = T * C, so x and add must be 2D and
3988
+ // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
3989
+ const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
3990
+ (add->ne[2] == 1 && add->ne[3] == 1) &&
3991
+ (a->ne[2] == 1 && a->ne[3] == 1);
3992
+ const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
3993
+
3994
+ // x must be in the supported whitelist and every operand / intermediate
3995
+ // result must share x's type, since launch_snake casts a / inv_b as
3996
+ // float and templates the kernel on a single T. Mixed precision chains
3997
+ // fall back to the naive path.
3998
+ const ggml_tensor * sin1 = cgraph->nodes[i + 1];
3999
+ const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
4000
+ (a->type == x->type) && (inv_b->type == x->type) &&
4001
+ (mul0->type == x->type) && (sin1->type == x->type) &&
4002
+ (sqr->type == x->type) && (mul1->type == x->type) &&
4003
+ (add->type == x->type);
4004
+
4005
+ if (types_ok && shape_ok && dim_ok && x_in_add == x) {
4006
+ ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
4007
+ return 4;
4008
+ }
4009
+ }
4010
+
4011
+ // multi-(add or mul)
4012
+ if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
4013
+ int n_fuse = 0;
4014
+ ggml_op ops[8];
4015
+ std::fill(ops, ops + 8, node->op);
4016
+
4017
+ for (; n_fuse <= 6; ++n_fuse) {
4018
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
4019
+ break;
4020
+ }
4021
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
4022
+ break;
4023
+ }
4024
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
4025
+ break;
4026
+ }
4027
+ }
4028
+
4029
+ n_fuse++;
4030
+
4031
+ if (n_fuse > 1) {
4032
+ ggml_tensor fused_node;
4033
+ memcpy(&fused_node, node, sizeof(ggml_tensor));
4034
+ for (int j = 0; j < n_fuse - 1; ++j) {
4035
+ fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
4036
+ }
4037
+ fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
4038
+ if (node->op == GGML_OP_ADD) {
4039
+ ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
4040
+ } else {
4041
+ ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
4042
+ }
4043
+ return n_fuse - 1;
4044
+ }
4045
+ }
4046
+
4047
+ bool fused_mul_mat_vec = false;
4048
+ int fused_node_count = 0;
4049
+
4050
+ // gate + glu + up
4051
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
4052
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
4053
+
4054
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
4055
+ ggml_tensor * glu = cgraph->nodes[i + 4];
4056
+ ggml_tensor * gate_bias_n = glu->src[0];
4057
+ ggml_tensor * up_bias_n = glu->src[1];
4058
+
4059
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
4060
+ ggml_tensor * gate_n = nullptr;
4061
+ ggml_tensor * up_n = nullptr;
4062
+
4063
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
4064
+ gate_n = cgraph->nodes[i];
4065
+ up_n = cgraph->nodes[i + 2];
4066
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
4067
+ gate_n = cgraph->nodes[i + 2];
4068
+ up_n = cgraph->nodes[i];
4069
+ } else {
4070
+ continue;
4071
+ }
4072
+
4073
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
4074
+ if (op_bias == GGML_OP_ADD) {
4075
+ if (bias_node->src[0] == mul_node) {
4076
+ return bias_node->src[1];
4077
+ }
4078
+ if (bias_node->src[1] == mul_node) {
4079
+ return bias_node->src[0];
4080
+ }
4081
+ return (ggml_tensor *) nullptr;
4082
+ }
4083
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
4084
+ GGML_ASSERT(bias_node->src[0] == mul_node);
4085
+ return bias_node->src[1];
4086
+ };
4087
+
4088
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
4089
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
4090
+
4091
+ if (!up_bias_tensor || !gate_bias_tensor) {
4092
+ continue;
4093
+ }
4094
+
4095
+ // we don't support repeating adds
4096
+ if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
4097
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
4098
+ continue;
4099
+ }
4100
+
4101
+ const ggml_tensor * src0 = up_n->src[0];
4102
+ const ggml_tensor * src1 = up_n->src[1];
4103
+ const ggml_tensor * ids = up_n->src[2];
4104
+
4105
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
4106
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4107
+ fusion_data.gate = gate_n->src[0];
4108
+ fusion_data.x_bias = up_bias_tensor;
4109
+ fusion_data.gate_bias = gate_bias_tensor;
4110
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4111
+
4112
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4113
+ fused_mul_mat_vec = true;
4114
+ fused_node_count = 5;
4115
+ break;
4116
+ }
4117
+
4118
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
4119
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4120
+ fusion_data.gate = gate_n->src[0];
4121
+ fusion_data.x_bias = up_bias_tensor;
4122
+ fusion_data.gate_bias = gate_bias_tensor;
4123
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4124
+
4125
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4126
+ fused_mul_mat_vec = true;
4127
+ fused_node_count = 5;
4128
+ break;
4129
+ }
4130
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
4131
+ ggml_tensor * glu = cgraph->nodes[i + 2];
4132
+ ggml_tensor * gate = glu->src[0];
4133
+ ggml_tensor * up = glu->src[1];
4134
+
4135
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) ||
4136
+ (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
4137
+
4138
+ if (!ok) {
4139
+ continue;
4140
+ }
4141
+
4142
+ const ggml_tensor * src0 = up->src[0];
4143
+ const ggml_tensor * src1 = up->src[1];
4144
+ const ggml_tensor * ids = up->src[2];
4145
+
4146
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
4147
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4148
+ fusion_data.gate = gate->src[0];
4149
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4150
+
4151
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4152
+ fused_mul_mat_vec = true;
4153
+ fused_node_count = 3;
4154
+ break;
4155
+ }
4156
+
4157
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
4158
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4159
+ fusion_data.gate = gate->src[0];
4160
+ fusion_data.glu_op = ggml_get_glu_op(glu);
4161
+
4162
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
4163
+ fused_mul_mat_vec = true;
4164
+ fused_node_count = 3;
4165
+ break;
4166
+ }
4167
+ }
4168
+ }
4169
+
4170
+ if (fused_mul_mat_vec) {
4171
+ return fused_node_count - 1;
4172
+ }
4173
+
4174
+ fused_mul_mat_vec = false;
4175
+ fused_node_count = 0;
4176
+
4177
+ // gate + add + glu + up + add
4178
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
4179
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
4180
+
4181
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
4182
+ continue;
4183
+ }
4184
+
4185
+ ggml_tensor * mm_node = cgraph->nodes[i];
4186
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
4187
+
4188
+ ggml_tensor * bias_tensor = nullptr;
4189
+ if (bias_op == GGML_OP_ADD) {
4190
+ if (bias_node->src[0] == mm_node) {
4191
+ bias_tensor = bias_node->src[1];
4192
+ } else if (bias_node->src[1] == mm_node) {
4193
+ bias_tensor = bias_node->src[0];
4194
+ } else {
4195
+ continue;
4196
+ }
4197
+ } else {
4198
+ if (bias_node->src[0] != mm_node) {
4199
+ continue;
4200
+ }
4201
+ bias_tensor = bias_node->src[1];
4202
+ }
4203
+
4204
+ const ggml_tensor * src0 = mm_node->src[0];
4205
+ const ggml_tensor * src1 = mm_node->src[1];
4206
+ const ggml_tensor * ids = mm_node->src[2];
4207
+
4208
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
4209
+ continue;
4210
+ }
4211
+
4212
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
4213
+ continue;
4214
+ }
4215
+
4216
+ ggml_cuda_mm_fusion_args_host fusion_data{};
4217
+ fusion_data.x_bias = bias_tensor;
4218
+
4219
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
4220
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
4221
+ fused_mul_mat_vec = true;
4222
+ fused_node_count = 2;
4223
+ break;
4224
+ }
4225
+
4226
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
4227
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
4228
+ fused_mul_mat_vec = true;
4229
+ fused_node_count = 2;
4230
+ break;
4231
+ }
4232
+ }
4233
+
4234
+ if (fused_mul_mat_vec) {
4235
+ return fused_node_count - 1;
4236
+ }
4237
+
4238
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
4239
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
4240
+ return 2;
4241
+ }
4242
+
4243
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
4244
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
4245
+ return 1;
4246
+ }
4247
+
4248
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
4249
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
4250
+ return 2;
4251
+ }
4252
+
4253
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
4254
+ ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
4255
+ return 1;
4256
+ }
4257
+
4258
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
4259
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
4260
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
4261
+ ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]);
4262
+ return 1;
4263
+ }
4264
+
4265
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
4266
+ ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]);
4267
+ return 1;
4268
+ }
4269
+
4270
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
4271
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node);
4272
+ return 2;
4273
+ }
4274
+
4275
+ return 0;
4276
+ }
4277
+
4278
+ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
3240
4279
  bool graph_evaluated_or_captured = false;
3241
4280
 
3242
4281
  // flag used to determine whether it is an integrated_gpu
@@ -3378,288 +4417,15 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3378
4417
  continue;
3379
4418
  }
3380
4419
 
4420
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
4421
+ continue;
4422
+ }
3381
4423
 
3382
- // start of fusion operations
3383
- static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
3384
- if (!disable_fusion) {
3385
-
3386
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
3387
- ggml_tensor * weights = cgraph->nodes[i + 9];
3388
- ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3389
- ggml_tensor * clamp = cgraph->nodes[i + 7];
3390
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
3391
- /*delayed softmax*/ false, clamp);
3392
- i += 9;
3393
- continue;
3394
- }
3395
-
3396
- if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
3397
- ggml_tensor * weights = cgraph->nodes[i + 4];
3398
- ggml_tensor * selected_experts = cgraph->nodes[i + 3];
3399
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
3400
- /*delayed softmax*/ false);
3401
- i += 4;
3402
- continue;
3403
- }
3404
-
3405
- if (ggml_cuda_can_fuse(cgraph, i,
3406
- ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
3407
- ggml_tensor * weights = cgraph->nodes[i + 5];
3408
- ggml_tensor * ids = cgraph->nodes[i + 1];
3409
-
3410
- ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
3411
- /*delayed_softmax*/ true);
3412
- i += 5;
3413
- continue;
3414
- }
3415
-
3416
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3417
- ggml_tensor * rope = cgraph->nodes[i];
3418
- ggml_tensor * set_rows = cgraph->nodes[i + 2];
3419
-
3420
- ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3421
- i += 2;
3422
- continue;
3423
- }
3424
-
3425
- if (node->op == GGML_OP_ADD) {
3426
- int n_fuse = 0;
3427
- ggml_op ops[8];
3428
- std::fill(ops, ops + 8, GGML_OP_ADD);
3429
-
3430
- for (; n_fuse <= 6; ++n_fuse){
3431
- if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
3432
- break;
3433
- }
3434
- if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
3435
- break;
3436
- }
3437
- if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
3438
- break;
3439
- }
3440
- }
3441
-
3442
- n_fuse++;
3443
-
3444
- if (n_fuse > 1) {
3445
- for (int j = 0; j < n_fuse - 1; ++j) {
3446
- node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3447
- }
3448
- cgraph->nodes[i + n_fuse - 1]->data = node->data;
3449
- ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
3450
- i += n_fuse - 1;
3451
-
3452
- continue;
3453
- }
3454
- }
3455
-
3456
- bool fused_mul_mat_vec = false;
3457
- int fused_node_count = 0;
3458
-
3459
- for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3460
- const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3461
-
3462
- if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3463
- ggml_tensor * glu = cgraph->nodes[i + 4];
3464
- ggml_tensor * gate_bias_n = glu->src[0];
3465
- ggml_tensor * up_bias_n = glu->src[1];
3466
-
3467
- //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3468
- ggml_tensor * gate_n = nullptr;
3469
- ggml_tensor * up_n = nullptr;
3470
-
3471
- if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3472
- gate_n = cgraph->nodes[i];
3473
- up_n = cgraph->nodes[i + 2];
3474
- } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3475
- gate_n = cgraph->nodes[i + 2];
3476
- up_n = cgraph->nodes[i];
3477
- } else {
3478
- continue;
3479
- }
3480
-
3481
- auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3482
- if (op_bias == GGML_OP_ADD) {
3483
- if (bias_node->src[0] == mul_node) {
3484
- return bias_node->src[1];
3485
- }
3486
- if (bias_node->src[1] == mul_node) {
3487
- return bias_node->src[0];
3488
- }
3489
- return (ggml_tensor *) nullptr;
3490
- }
3491
- GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3492
- GGML_ASSERT(bias_node->src[0] == mul_node);
3493
- return bias_node->src[1];
3494
- };
3495
-
3496
- ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
3497
- ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3498
-
3499
- if (!up_bias_tensor || !gate_bias_tensor) {
3500
- continue;
3501
- }
3502
-
3503
- // we don't support repeating adds
3504
- if (bias_op == GGML_OP_ADD &&
3505
- (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3506
- !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3507
- continue;
3508
- }
3509
-
3510
- const ggml_tensor * src0 = up_n->src[0];
3511
- const ggml_tensor * src1 = up_n->src[1];
3512
- const ggml_tensor * ids = up_n->src[2];
3513
-
3514
- if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3515
- ggml_cuda_mm_fusion_args_host fusion_data{};
3516
- fusion_data.gate = gate_n->src[0];
3517
- fusion_data.x_bias = up_bias_tensor;
3518
- fusion_data.gate_bias = gate_bias_tensor;
3519
- fusion_data.glu_op = ggml_get_glu_op(glu);
3520
-
3521
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3522
- fused_mul_mat_vec = true;
3523
- fused_node_count = 5;
3524
- break;
3525
- }
3526
-
3527
- if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3528
- ggml_cuda_mm_fusion_args_host fusion_data{};
3529
- fusion_data.gate = gate_n->src[0];
3530
- fusion_data.x_bias = up_bias_tensor;
3531
- fusion_data.gate_bias = gate_bias_tensor;
3532
- fusion_data.glu_op = ggml_get_glu_op(glu);
3533
-
3534
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3535
- fused_mul_mat_vec = true;
3536
- fused_node_count = 5;
3537
- break;
3538
- }
3539
- } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3540
- ggml_tensor * glu = cgraph->nodes[i + 2];
3541
- ggml_tensor * gate = glu->src[0];
3542
- ggml_tensor * up = glu->src[1];
3543
-
3544
- bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3545
- || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3546
-
3547
- if (!ok) continue;
3548
-
3549
- const ggml_tensor * src0 = up->src[0];
3550
- const ggml_tensor * src1 = up->src[1];
3551
- const ggml_tensor * ids = up->src[2];
3552
-
3553
- if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3554
- ggml_cuda_mm_fusion_args_host fusion_data{};
3555
- fusion_data.gate = gate->src[0];
3556
- fusion_data.glu_op = ggml_get_glu_op(glu);
3557
-
3558
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3559
- fused_mul_mat_vec = true;
3560
- fused_node_count = 3;
3561
- break;
3562
- }
3563
-
3564
- if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3565
- ggml_cuda_mm_fusion_args_host fusion_data{};
3566
- fusion_data.gate = gate->src[0];
3567
- fusion_data.glu_op = ggml_get_glu_op(glu);
3568
-
3569
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3570
- fused_mul_mat_vec = true;
3571
- fused_node_count = 3;
3572
- break;
3573
- }
3574
- }
3575
- }
3576
-
3577
- if (fused_mul_mat_vec) {
3578
- i += fused_node_count - 1;
3579
- continue;
3580
- }
3581
-
3582
- fused_mul_mat_vec = false;
3583
- fused_node_count = 0;
3584
-
3585
- for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3586
- const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3587
-
3588
- if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3589
- continue;
3590
- }
3591
-
3592
- ggml_tensor * mm_node = cgraph->nodes[i];
3593
- ggml_tensor * bias_node = cgraph->nodes[i + 1];
3594
-
3595
- ggml_tensor * bias_tensor = nullptr;
3596
- if (bias_op == GGML_OP_ADD) {
3597
- if (bias_node->src[0] == mm_node) {
3598
- bias_tensor = bias_node->src[1];
3599
- } else if (bias_node->src[1] == mm_node) {
3600
- bias_tensor = bias_node->src[0];
3601
- } else {
3602
- continue;
3603
- }
3604
- } else {
3605
- if (bias_node->src[0] != mm_node) {
3606
- continue;
3607
- }
3608
- bias_tensor = bias_node->src[1];
3609
- }
3610
-
3611
- const ggml_tensor * src0 = mm_node->src[0];
3612
- const ggml_tensor * src1 = mm_node->src[1];
3613
- const ggml_tensor * ids = mm_node->src[2];
3614
-
3615
- if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3616
- continue;
3617
- }
3618
-
3619
- if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3620
- continue;
3621
- }
3622
-
3623
- ggml_cuda_mm_fusion_args_host fusion_data{};
3624
- fusion_data.x_bias = bias_tensor;
3625
-
3626
- if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3627
- ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3628
- fused_mul_mat_vec = true;
3629
- fused_node_count = 2;
3630
- break;
3631
- }
3632
-
3633
- if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3634
- ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3635
- fused_mul_mat_vec = true;
3636
- fused_node_count = 2;
3637
- break;
3638
- }
3639
- }
3640
-
3641
- if (fused_mul_mat_vec) {
3642
- i += fused_node_count - 1;
3643
- continue;
3644
- }
3645
-
3646
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3647
- ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
3648
- i += 2;
3649
- continue;
3650
- }
3651
-
3652
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
3653
- ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
3654
- i++;
3655
- continue;
3656
- }
4424
+ int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i);
3657
4425
 
3658
- if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
3659
- i += 2;
3660
- ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
3661
- continue;
3662
- }
4426
+ if (nodes_to_skip != 0) {
4427
+ i += nodes_to_skip;
4428
+ continue;
3663
4429
  }
3664
4430
  #ifndef NDEBUG
3665
4431
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
@@ -3687,13 +4453,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3687
4453
  }
3688
4454
 
3689
4455
  #ifdef USE_CUDA_GRAPH
4456
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3690
4457
  if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
3691
- if (cuda_ctx->cuda_graph->graph != nullptr) {
3692
- CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
3693
- cuda_ctx->cuda_graph->graph = nullptr;
4458
+ if (graph->graph != nullptr) {
4459
+ CUDA_CHECK(cudaGraphDestroy(graph->graph));
4460
+ graph->graph = nullptr;
3694
4461
  }
3695
4462
 
3696
- CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
4463
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
3697
4464
  graph_evaluated_or_captured = true; // CUDA graph has been captured
3698
4465
 
3699
4466
  std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3706,41 +4473,38 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
3706
4473
  }
3707
4474
 
3708
4475
  if (use_cuda_graph) {
3709
- if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
3710
- CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
4476
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4477
+ if (graph->instance == nullptr) { // Create executable graph from captured graph.
4478
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3711
4479
  }
3712
4480
  if (cuda_graph_update_required) { // Update graph executable
3713
- ggml_cuda_graph_update_executable(cuda_ctx);
4481
+ ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
3714
4482
  }
3715
4483
  // Launch graph
3716
- CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
4484
+ CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
3717
4485
  #else
4486
+ GGML_UNUSED(graph_key);
3718
4487
  graph_evaluated_or_captured = true;
3719
4488
  #endif // USE_CUDA_GRAPH
3720
4489
  }
3721
4490
  }
3722
4491
 
3723
- static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
3724
-
3725
4492
  #ifdef USE_CUDA_GRAPH
4493
+ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
4494
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3726
4495
 
3727
- if (cuda_ctx->cuda_graph == nullptr) {
3728
- cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
3729
- }
3730
-
3731
- if (cuda_ctx->cuda_graph->graph == nullptr) {
4496
+ if (graph->graph == nullptr) {
3732
4497
  if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3733
- cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3734
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
4498
+ if (!graph->disable_due_to_gpu_arch) {
4499
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
4500
+ }
4501
+ graph->disable_due_to_gpu_arch = true;
3735
4502
  }
3736
4503
  }
3737
4504
 
3738
- return cuda_ctx->cuda_graph->is_enabled();
3739
- #else
3740
- GGML_UNUSED(cuda_ctx);
3741
- return false;
3742
- #endif // USE_CUDA_GRAPH
4505
+ return graph->is_enabled();
3743
4506
  }
4507
+ #endif // USE_CUDA_GRAPH
3744
4508
 
3745
4509
  static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3746
4510
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@@ -3749,15 +4513,40 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3749
4513
 
3750
4514
  bool use_cuda_graph = false;
3751
4515
  bool cuda_graph_update_required = false;
4516
+ const void * graph_key = nullptr;
3752
4517
 
3753
4518
  #ifdef USE_CUDA_GRAPH
3754
- use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
3755
-
3756
- if (cuda_ctx->cuda_graph->is_enabled()) {
3757
- cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3758
- use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
3759
-
3760
- cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
4519
+ graph_key = ggml_cuda_graph_get_key(cgraph);
4520
+
4521
+ ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4522
+
4523
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
4524
+ if (graph->is_enabled()) {
4525
+ const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
4526
+ if (graph_compatible) {
4527
+ const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
4528
+
4529
+ if (!graph->warmup_complete) {
4530
+ // Warmup: need at least 2 calls with no property change on the 2nd call
4531
+ if (!properties_changed) {
4532
+ graph->warmup_complete = true;
4533
+ GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
4534
+ use_cuda_graph = true;
4535
+ cuda_graph_update_required = true;
4536
+ }
4537
+ // else: properties changed or first call - execute directly (use_cuda_graph stays false)
4538
+ } else {
4539
+ // Post-warmup: normal CUDA graph operation
4540
+ if (properties_changed) {
4541
+ // Properties changed - reset warmup, execute directly until stable again
4542
+ graph->warmup_complete = false;
4543
+ GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
4544
+ } else {
4545
+ use_cuda_graph = true;
4546
+ cuda_graph_update_required = graph->instance == nullptr;
4547
+ }
4548
+ }
4549
+ }
3761
4550
  }
3762
4551
  #endif // USE_CUDA_GRAPH
3763
4552
 
@@ -3771,7 +4560,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
3771
4560
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3772
4561
  }
3773
4562
 
3774
- ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
4563
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
3775
4564
 
3776
4565
  return GGML_STATUS_SUCCESS;
3777
4566
  }
@@ -3804,7 +4593,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
3804
4593
  static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
3805
4594
  ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3806
4595
 
3807
- const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
4596
+ #ifdef USE_CUDA_GRAPH
4597
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
4598
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4599
+ #else
4600
+ const bool use_cuda_graph = false;
4601
+ GGML_UNUSED(cuda_ctx);
4602
+ GGML_UNUSED(cgraph);
4603
+ #endif
3808
4604
 
3809
4605
  static bool enable_graph_optimization = [] {
3810
4606
  const char * env = getenv("GGML_CUDA_GRAPH_OPT");
@@ -4043,6 +4839,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
4043
4839
  /* .free = */ ggml_backend_cuda_free,
4044
4840
  /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
4045
4841
  /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
4842
+ /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async,
4843
+ /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async,
4046
4844
  /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
4047
4845
  /* .synchronize = */ ggml_backend_cuda_synchronize,
4048
4846
  /* .graph_plan_create = */ NULL,
@@ -4118,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
4118
4916
 
4119
4917
  // backend device
4120
4918
 
4121
- struct ggml_backend_cuda_device_context {
4122
- int device;
4123
- std::string name;
4124
- std::string description;
4125
- std::string pci_bus_id;
4126
- int op_offload_min_batch_size;
4127
- };
4128
-
4129
4919
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
4130
4920
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
4131
4921
  return ctx->name.c_str();
@@ -4214,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
4214
5004
 
4215
5005
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
4216
5006
  ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
5007
+
5008
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5009
+ std::lock_guard<std::mutex> lock(ctx->device_mutex);
5010
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5011
+
4217
5012
  ggml_cuda_set_device(ctx->device);
4218
5013
  CUDA_CHECK(cudaMemGetInfo(free, total));
4219
5014
 
@@ -4240,11 +5035,24 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
4240
5035
  }
4241
5036
  #endif // defined(__linux__)
4242
5037
 
5038
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5039
+ // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
5040
+ // context that permanently consumes VRAM. Reset the device to free it.
5041
+ if (ctx->active_count == 0) {
5042
+ CUDA_CHECK(cudaDeviceReset());
5043
+ }
5044
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
4243
5045
  }
4244
5046
 
4245
5047
  static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
4246
- GGML_UNUSED(dev);
4247
- return GGML_BACKEND_DEVICE_TYPE_GPU;
5048
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context;
5049
+
5050
+ cudaDeviceProp prop;
5051
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
5052
+
5053
+ return prop.integrated
5054
+ ? GGML_BACKEND_DEVICE_TYPE_IGPU
5055
+ : GGML_BACKEND_DEVICE_TYPE_GPU;
4248
5056
  }
4249
5057
 
4250
5058
  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
@@ -4335,6 +5143,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4335
5143
  case GGML_UNARY_OP_CEIL:
4336
5144
  case GGML_UNARY_OP_ROUND:
4337
5145
  case GGML_UNARY_OP_TRUNC:
5146
+ // TODO: should become:
5147
+ //return ggml_is_contiguous_rows(op->src[0]);
4338
5148
  return ggml_is_contiguous(op->src[0]);
4339
5149
  default:
4340
5150
  return false;
@@ -4391,12 +5201,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4391
5201
  switch (a->type) {
4392
5202
  case GGML_TYPE_F32:
4393
5203
  case GGML_TYPE_F16:
5204
+ case GGML_TYPE_Q1_0:
4394
5205
  case GGML_TYPE_Q4_0:
4395
5206
  case GGML_TYPE_Q4_1:
4396
5207
  case GGML_TYPE_Q5_0:
4397
5208
  case GGML_TYPE_Q5_1:
4398
5209
  case GGML_TYPE_Q8_0:
4399
5210
  case GGML_TYPE_MXFP4:
5211
+ case GGML_TYPE_NVFP4:
4400
5212
  case GGML_TYPE_Q2_K:
4401
5213
  case GGML_TYPE_Q3_K:
4402
5214
  case GGML_TYPE_Q4_K:
@@ -4427,6 +5239,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4427
5239
  case GGML_TYPE_F32:
4428
5240
  case GGML_TYPE_BF16:
4429
5241
  case GGML_TYPE_I32:
5242
+ case GGML_TYPE_Q1_0:
4430
5243
  case GGML_TYPE_Q4_0:
4431
5244
  case GGML_TYPE_Q4_1:
4432
5245
  case GGML_TYPE_Q5_0:
@@ -4532,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4532
5345
  case GGML_OP_CONCAT:
4533
5346
  {
4534
5347
  ggml_type src0_type = op->src[0]->type;
4535
- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
5348
+ ggml_type src1_type = op->src[1]->type;
5349
+ return src0_type == src1_type &&
5350
+ src0_type == op->type &&
5351
+ !ggml_is_quantized(src0_type) &&
5352
+ ggml_blck_size(src0_type) == 1 &&
5353
+ (ggml_type_size(src0_type) == 1 ||
5354
+ ggml_type_size(src0_type) == 2 ||
5355
+ ggml_type_size(src0_type) == 4 ||
5356
+ ggml_type_size(src0_type) == 8);
4536
5357
  } break;
4537
5358
  case GGML_OP_CONV_TRANSPOSE_1D:
4538
5359
  {
@@ -4551,19 +5372,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4551
5372
  case GGML_OP_L2_NORM:
4552
5373
  return true;
4553
5374
  case GGML_OP_RMS_NORM_BACK:
4554
- return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
5375
+ return ggml_is_contiguous(op->src[0]);
4555
5376
  break;
4556
5377
  case GGML_OP_NONE:
4557
5378
  case GGML_OP_RESHAPE:
4558
5379
  case GGML_OP_VIEW:
4559
5380
  case GGML_OP_PERMUTE:
4560
5381
  case GGML_OP_TRANSPOSE:
4561
- case GGML_OP_ADD:
4562
5382
  case GGML_OP_ADD_ID:
4563
5383
  case GGML_OP_ADD1:
4564
- case GGML_OP_SUB:
4565
- case GGML_OP_MUL:
4566
- case GGML_OP_DIV:
4567
5384
  case GGML_OP_SCALE:
4568
5385
  case GGML_OP_SQR:
4569
5386
  case GGML_OP_SQRT:
@@ -4572,6 +5389,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4572
5389
  case GGML_OP_CLAMP:
4573
5390
  case GGML_OP_LOG:
4574
5391
  return true;
5392
+ case GGML_OP_ADD:
5393
+ case GGML_OP_SUB:
5394
+ case GGML_OP_MUL:
5395
+ case GGML_OP_DIV:
5396
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
5397
+ (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
5398
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
4575
5399
  case GGML_OP_SSM_SCAN: {
4576
5400
  if (op->src[3]->ne[0] == 1) {
4577
5401
  // Mamba2
@@ -4613,8 +5437,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4613
5437
  case GGML_OP_CONV_2D_DW:
4614
5438
  case GGML_OP_CONV_TRANSPOSE_2D:
4615
5439
  case GGML_OP_POOL_2D:
4616
- case GGML_OP_ACC:
4617
5440
  return true;
5441
+ case GGML_OP_ACC:
5442
+ // TODO: extend support like so:
5443
+ //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
5444
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
4618
5445
  case GGML_OP_SUM:
4619
5446
  return ggml_is_contiguous_rows(op->src[0]);
4620
5447
  case GGML_OP_TOP_K:
@@ -4627,8 +5454,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4627
5454
  case GGML_OP_SUM_ROWS:
4628
5455
  case GGML_OP_MEAN:
4629
5456
  case GGML_OP_GROUP_NORM:
4630
- case GGML_OP_PAD:
4631
5457
  return ggml_is_contiguous(op->src[0]);
5458
+ case GGML_OP_PAD:
5459
+ return true;
4632
5460
  case GGML_OP_UPSCALE:
4633
5461
  case GGML_OP_PAD_REFLECT_1D:
4634
5462
  case GGML_OP_ARANGE:
@@ -4638,6 +5466,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
4638
5466
  case GGML_OP_GATED_LINEAR_ATTN:
4639
5467
  case GGML_OP_RWKV_WKV7:
4640
5468
  return true;
5469
+ case GGML_OP_GATED_DELTA_NET:
5470
+ //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
5471
+ #ifdef GGML_USE_MUSA
5472
+ return false;
5473
+ #else
5474
+ return true;
5475
+ #endif // GGML_USE_MUSA
4641
5476
  case GGML_OP_FLASH_ATTN_EXT:
4642
5477
  return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
4643
5478
  case GGML_OP_CROSS_ENTROPY_LOSS:
@@ -4816,6 +5651,15 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
4816
5651
 
4817
5652
  static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
4818
5653
  GGML_UNUSED(reg);
5654
+ if (strcmp(name, "ggml_backend_comm_init") == 0) {
5655
+ return (void *)ggml_backend_cuda_comm_init;
5656
+ }
5657
+ if (strcmp(name, "ggml_backend_comm_free") == 0) {
5658
+ return (void *)ggml_backend_cuda_comm_free;
5659
+ }
5660
+ if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
5661
+ return (void *)ggml_backend_cuda_comm_allreduce_tensor;
5662
+ }
4819
5663
  if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
4820
5664
  return (void *)ggml_backend_cuda_split_buffer_type;
4821
5665
  }
@@ -4859,9 +5703,12 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
4859
5703
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
4860
5704
  dev_ctx->description = prop.name;
4861
5705
 
4862
- char pci_bus_id[16] = {};
4863
- snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
5706
+ char pci_bus_id[32] = {};
5707
+ CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
4864
5708
  dev_ctx->pci_bus_id = pci_bus_id;
5709
+ for (char & c : dev_ctx->pci_bus_id) {
5710
+ c = std::tolower(c);
5711
+ }
4865
5712
  dev_ctx->op_offload_min_batch_size = min_batch_size;
4866
5713
 
4867
5714
  ggml_backend_dev_t dev = new ggml_backend_device {
@@ -4897,13 +5744,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
4897
5744
  return nullptr;
4898
5745
  }
4899
5746
 
5747
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);
5748
+
4900
5749
  ggml_backend_t cuda_backend = new ggml_backend {
4901
5750
  /* .guid = */ ggml_backend_cuda_guid(),
4902
5751
  /* .iface = */ ggml_backend_cuda_interface,
4903
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
5752
+ /* .device = */ dev,
4904
5753
  /* .context = */ ctx,
4905
5754
  };
4906
5755
 
5756
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5757
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
5758
+ std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
5759
+ dev_ctx->active_count++;
5760
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
5761
+
4907
5762
  return cuda_backend;
4908
5763
  }
4909
5764