whispercpp 1.3.2 → 1.3.4

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu

@@ -4,6 +4,7 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
@@ -11,6 +12,9 @@
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d.cuh"
+#include "ggml-cuda/conv2d-dw.cuh"
+#include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
@@ -19,27 +23,35 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
+#include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -47,16 +59,17 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
+#include <initializer_list>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -97,8 +110,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -116,15 +128,13 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
         err = cudaMalloc(ptr, size);
     }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
 }
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 static int ggml_cuda_parse_id(char devName[]) {
     // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
     // these values are not stable so this is susceptible to breakage
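Note on the unified-memory path above: ggml_cuda_device_malloc only switches to cudaMallocManaged when the GGML_CUDA_ENABLE_UNIFIED_MEMORY environment variable is set, so the toggle has to be in the environment before the backend allocates anything. Below is a minimal sketch (not part of this diff) of opting in from a host program, assuming the public ggml CUDA entry points ggml_backend_cuda_init and ggml_backend_free and POSIX setenv; exporting the variable in the shell before launch works just as well:

    #include <cstdlib>        // setenv (POSIX)
    #include "ggml-backend.h"
    #include "ggml-cuda.h"

    int main() {
        // ggml_cuda_device_malloc (see the hunk above) calls getenv() at
        // allocation time, so set this before the first CUDA allocation.
        setenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY", "1", /*overwrite=*/1);

        ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
        // ... create buffers and run graphs as usual ...
        ggml_backend_free(backend);
        return 0;
    }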
@@ -171,33 +181,9 @@ static int ggml_cuda_parse_id(char devName[]) {
     archNum += archMinor;
     return archNum;
 }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 
 static ggml_cuda_device_info ggml_cuda_init() {
-#ifdef __HIP_PLATFORM_AMD__
-    // Workaround for a rocBLAS bug when using multiple graphics cards:
-    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    {
-        int major_version = 0;
-        size_t version_length = 0;
-        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
-            std::vector<char> version(version_length+1, '\0');
-            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
-                version.resize(::strlen(version.data()));
-                int parsed_value = 0;
-                if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
-                    major_version = parsed_value;
-                }
-            }
-        }
-        if (major_version < 4) {
-            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
-            rocblas_initialize();
-            CUDA_CHECK(cudaDeviceSynchronize());
-        }
-    }
-#endif
-
     ggml_cuda_device_info info = {};
 
     cudaError_t err = cudaGetDeviceCount(&info.device_count);
@@ -220,6 +206,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
 #endif // GGML_CUDA_FORCE_CUBLAS
     GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+
+    std::vector<std::pair<int, std::string>> turing_devices_without_mma;
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
@@ -243,11 +231,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
-
-        info.devices[id].nsm = prop.multiProcessorCount;
-        info.devices[id].smpb = prop.sharedMemPerBlock;
-        info.devices[id].warp_size = prop.warpSize;
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].integrated = prop.integrated;
+        info.devices[id].nsm        = prop.multiProcessorCount;
+        info.devices[id].smpb       = prop.sharedMemPerBlock;
+        info.devices[id].warp_size  = prop.warpSize;
+#if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
         info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
@@ -277,7 +265,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+        std::string device_name(prop.name);
+        if (device_name == "NVIDIA GeForce MX450") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        } else if (device_name == "NVIDIA GeForce MX550") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
+            turing_devices_without_mma.push_back({ id, device_name });
+        }
+#endif // defined(GGML_USE_HIP)
+    }
+
+    if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
+        GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
+        for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
+            GGML_LOG_INFO(
+                "  Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
+        }
+        GGML_LOG_INFO(
+            "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -514,6 +520,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514
520
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
515
521
  }
516
522
 
523
+ // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
524
+ // this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
525
+
526
+ static std::mutex ggml_cuda_lock;
527
+ static std::condition_variable ggml_cuda_lock_cv;
528
+ static std::atomic<int> ggml_cuda_lock_counter;
529
+
530
+ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
531
+ std::unique_lock<std::mutex> lock(ggml_cuda_lock);
532
+ ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
533
+
534
+ if (copy_event != nullptr) {
535
+ CUDA_CHECK(cudaEventDestroy(copy_event));
536
+ }
537
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
538
+ for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
539
+ if (streams[i][j] != nullptr) {
540
+ CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
541
+ }
542
+ }
543
+ if (cublas_handles[i] != nullptr) {
544
+ CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
545
+ }
546
+ }
547
+ }
548
+
549
+
517
550
  // cuda buffer
518
551
 
519
552
  struct ggml_backend_cuda_buffer_context {
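The new destructor above waits for in-flight graph captures before tearing down cuBLAS handles and streams. A stripped-down sketch of the same mutex/condition-variable/atomic-counter guard, with shortened names and no CUDA dependency:

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

// Capture threads increment a counter; the destructor waits until it drains to zero.
static std::mutex              demo_lock;
static std::condition_variable demo_cv;
static std::atomic<int>        demo_captures{0};

void demo_begin_capture() {
    std::lock_guard<std::mutex> lock(demo_lock);
    demo_captures.fetch_add(1, std::memory_order_relaxed);
}

void demo_end_capture() {
    std::lock_guard<std::mutex> lock(demo_lock);
    if (demo_captures.fetch_sub(1, std::memory_order_relaxed) == 1) {
        demo_cv.notify_all(); // wake a destructor waiting for captures to drain
    }
}

void demo_destroy_handles() {
    std::unique_lock<std::mutex> lock(demo_lock);
    demo_cv.wait(lock, [] { return demo_captures.load(std::memory_order_relaxed) == 0; });
    // it is now safe to destroy cuBLAS handles / CUDA streams
}
```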
@@ -615,9 +648,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

  ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaDeviceSynchronize());
- CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
- CUDA_CHECK(cudaDeviceSynchronize());
+ CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
  }

  static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
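The buffer-clear change above replaces two device-wide synchronizations with an async memset on the per-thread default stream, so only that stream is awaited. A minimal sketch, assuming only the CUDA runtime:

```cpp
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

// Clear a device buffer without stalling every stream on the device: the memset
// is queued on the per-thread default stream and only that stream is synchronized.
void demo_clear_buffer(void * dev_ptr, size_t size, uint8_t value) {
    cudaMemsetAsync(dev_ptr, value, size, cudaStreamPerThread);
    cudaStreamSynchronize(cudaStreamPerThread);
}
```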
@@ -1065,6 +1097,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
  GGML_UNUSED(buft);
  }

+ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+ }
+
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  CUDA_CHECK(cudaFreeHost(buffer->context));
  }
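The new ggml_backend_buft_is_cuda_host identifies the host buffer type by comparing the vtable's get_name function pointer rather than the string it returns. A toy illustration of the idiom (the struct names are stand-ins for the real ggml types):

```cpp
// Identify a buffer type by the identity of its get_name function pointer.
struct demo_iface { const char * (*get_name)(); };
struct demo_buft  { demo_iface iface; };

static const char * demo_host_name() { return "CUDA_Host"; }

static bool demo_is_host(const demo_buft * buft) {
    // pointer comparison is cheap and cannot collide the way strings can
    return buft->iface.get_name == demo_host_name;
}
```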
@@ -1140,7 +1176,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
  static cudaError_t ggml_cuda_cpy_tensor_2d(
  void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

- GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
  const char * src_ptr = (const char *) src->data;
  char * dst_ptr = (char *) dst;

@@ -1198,9 +1233,12 @@ static void ggml_cuda_op_mul_mat_cublas(

  const int cc = ggml_cuda_info().devices[id].cc;

+ const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
  const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

- if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+ if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
  ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
  if (src1->type != GGML_TYPE_BF16) {
  const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1266,7 @@ static void ggml_cuda_op_mul_mat_cublas(

  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
  to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
- } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+ } else if (fast_fp16_hardware_available(cc) && use_fp16) {
  // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
  ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
  if (src0->type != GGML_TYPE_F16) {
@@ -1313,9 +1351,7 @@ static void ggml_cuda_op_mul_mat_cublas(
  &beta, dst_dd_i, ldc));
  }

- GGML_UNUSED(dst);
- GGML_UNUSED(src1_ddq_i);
- GGML_UNUSED(src1_padded_row_size);
+ GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
  }

  static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
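GGML_UNUSED_VARS, used above, collapses the three single-variable GGML_UNUSED calls into one. The real macro is defined in ggml's headers; a plausible C++17 sketch of such a variadic helper could look like this (all names here are hypothetical):

```cpp
// Evaluate-and-discard each argument via a C++17 fold expression over the comma operator.
template <typename... Args>
inline void demo_unused_impl(Args &&... args) {
    ((void) args, ...);
}
#define DEMO_UNUSED_VARS(...) demo_unused_impl(__VA_ARGS__)
```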
@@ -1423,8 +1459,6 @@ static void ggml_cuda_op_mul_mat(
  const int64_t nb2 = dst->nb[2];
  const int64_t nb3 = dst->nb[3];

- GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
  ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
  ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;

@@ -1719,7 +1753,7 @@ static void ggml_cuda_op_mul_mat(
  }

  static __global__ void k_compute_batched_ptrs(
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
  const void ** ptrs_src, void ** ptrs_dst,
  int64_t ne12, int64_t ne13,
  int64_t ne23,
@@ -1742,83 +1776,136 @@ static __global__ void k_compute_batched_ptrs(
  ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
  }

- static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
+ template<ggml_type T>
+ struct batched_mul_mat_traits;
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_F32> {
+ using cuda_type = float;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+ };
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+ using cuda_type = nv_bfloat16;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+ };
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_F16> {
+ using cuda_type = half;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+ static inline const half alpha = 1.0;
+ static inline const half beta = 0.0;
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
+ static inline const void* get_beta() { static const half val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+ };
+
+ template<ggml_type src0_type>
+ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ using traits = batched_mul_mat_traits<src0_type>;
+ using cuda_t = typename traits::cuda_type;
+
  GGML_ASSERT(!ggml_is_transposed(src0));
  GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
+ GGML_ASSERT(src0->type == src0_type);
+ GGML_ASSERT(ggml_is_contiguous(dst));

  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
  // As long as dst is contiguous this does not matter though.
- GGML_ASSERT(ggml_is_contiguous(dst));

  GGML_TENSOR_BINARY_OP_LOCALS

  const int64_t ne_dst = ggml_nelements(dst);
-
  cudaStream_t main_stream = ctx.stream();
-
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

- const half * src0_f16 = (const half *) src0->data;
  float * dst_ddf = (float *) dst->data;
-
- const half * src1_f16 = (const half *) src1->data;
  const size_t ts_src1 = ggml_type_size(src1->type);
  GGML_ASSERT(nb10 == ts_src1);
  int64_t s11 = nb11 / ts_src1;
  int64_t s12 = nb12 / ts_src1;
  int64_t s13 = nb13 / ts_src1;
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());

- // convert src1 to fp16
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
- const int64_t ne_src1 = ggml_nelements(src1);
- src1_f16_alloc.alloc(ne_src1);
- GGML_ASSERT(to_fp16_cuda != nullptr);
+ const cuda_t * src0_ptr = nullptr;
+ const cuda_t * src1_ptr = nullptr;

- to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());

- src1_f16 = src1_f16_alloc.get();
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
+ // Handle src0
+ src0_ptr = (const cuda_t *) src0->data;
+
+ // Handle src1 - convert if necessary
+ if (src1->type == src0_type) {
+ src1_ptr = (const cuda_t *) src1->data;
+ } else {
+ // Convert src1 to target type using traits conversion functions
+ const int64_t ne_src1 = ggml_nelements(src1);
+ src1_alloc.alloc(ne_src1);
+
+ const auto convert_func = traits::get_nc_converter(src1->type);
+ GGML_ASSERT(convert_func != nullptr);
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+ src1_ptr = src1_alloc.get();
  s11 = ne10;
  s12 = ne11*s11;
  s13 = ne12*s12;
+
+ is_src1_cont_2 = true;
  }

- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+ // Setup destination buffer
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
  char * dst_t;
-
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
- cudaDataType_t cu_data_type = CUDA_R_16F;
-
- // dst strides
  size_t nbd2 = dst->nb[2];
  size_t nbd3 = dst->nb[3];

- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
+ cublasComputeType_t cu_compute_type = traits::compute_type;
+ cudaDataType_t cu_data_type = traits::data_type;
+ cudaDataType_t cu_data_type_a = traits::data_type;
+ cudaDataType_t cu_data_type_b = traits::data_type;
+ const void * alpha = traits::get_alpha();
+ const void * beta = traits::get_beta();
  const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
-
- const void * alpha = &alpha_f16;
- const void * beta = &beta_f16;
+ const float beta_f32 = 0.0f;

  if (dst->op_params[0] == GGML_PREC_DEFAULT) {
- dst_t = (char *) dst_f16.alloc(ne_dst);
-
- nbd2 /= sizeof(float) / sizeof(half);
- nbd3 /= sizeof(float) / sizeof(half);
+ if constexpr (src0_type == GGML_TYPE_F32) {
+ dst_t = (char *) dst_ddf; // Direct F32 output
+ } else {
+ dst_t = (char *) dst_temp.alloc(ne_dst);
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
+ }
  } else {
  dst_t = (char *) dst_ddf;
-
  cu_compute_type = CUBLAS_COMPUTE_32F;
- cu_data_type = CUDA_R_32F;
-
+ cu_data_type = CUDA_R_32F;
  alpha = &alpha_f32;
- beta = &beta_f32;
+ beta = &beta_f32;
  }

  int id = ggml_cuda_get_device();
@@ -1826,7 +1913,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
  cu_compute_type = CUBLAS_COMPUTE_32F;
  alpha = &alpha_f32;
- beta = &beta_f32;
+ beta = &beta_f32;
  }

  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1836,35 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  const int64_t r2 = ne12/ne02;
  const int64_t r3 = ne13/ne03;

- #if 0
- // use cublasGemmEx
- {
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- int i03 = i13 / r3;
- int i02 = i12 / r2;
-
- CUBLAS_CHECK(
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- }
- }
- }
- #else
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+ const int64_t smb = ne12 == 1 ? s13 : s12;
+
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  // use cublasGemmStridedBatchedEx
  CUBLAS_CHECK(
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
  ne01, ne11, ne10,
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
- src1_f16, CUDA_R_16F, s11, s12, // strideB
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
  ne12*ne13,
  cu_compute_type,
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1875,34 +1946,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
  ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);

+ size_t src1_stride_size = sizeof(cuda_t);
+
  dim3 block_dims(ne13, ne12);
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_f16, src1_f16, dst_t,
+ src0_ptr, src1_ptr, dst_t,
  ptrs_src.get(), ptrs_dst.get(),
  ne12, ne13,
  ne23,
  nb02, nb03,
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
  nbd2, nbd3,
  r2, r3);
+
  CUDA_CHECK(cudaGetLastError());

  CUBLAS_CHECK(
  cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
  ne01, ne11, ne10,
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
  ne23,
  cu_compute_type,
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  }
- #endif

- if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+ // Convert output back to F32 if needed
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+ }
+ }
+
+ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_BF16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_F16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+ break;
+ default:
+ GGML_ABORT("Unsupported type");
  }
  }
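The rewrite above maps a runtime ggml_type to a compile-time specialization via batched_mul_mat_traits plus a single switch. A self-contained sketch of that traits-plus-switch dispatch pattern, with stand-in enum and trait names:

```cpp
#include <cstdio>

enum demo_type { DEMO_F32, DEMO_F16 };

// One specialization per enumerator carries that case's types and constants.
template <demo_type T> struct demo_traits;
template <> struct demo_traits<DEMO_F32> { using elem = float; static constexpr const char * name = "f32"; };
template <> struct demo_traits<DEMO_F16> { using elem = short; static constexpr const char * name = "f16"; };

template <demo_type T>
static void demo_run_impl() {
    using traits = demo_traits<T>;
    std::printf("dispatched to %s (element size %zu)\n", traits::name, sizeof(typename traits::elem));
}

// Runtime value -> compile-time specialization; exactly one switch at the boundary.
static void demo_run(demo_type t) {
    switch (t) {
        case DEMO_F32: demo_run_impl<DEMO_F32>(); break;
        case DEMO_F16: demo_run_impl<DEMO_F16>(); break;
    }
}
```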

@@ -1915,17 +2007,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
  && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;

- bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+ bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+ bool use_mul_mat_f = !ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
  bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
  && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
  bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

- bool any_gpus_with_slow_fp16 = false;
- bool any_gpus_without_fp16_mma = false;
+ bool any_gpus_with_slow_fp16 = false;

  if (split) {
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +2028,20 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  continue;
  }

- const int cc = ggml_cuda_info().devices[id].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }
  } else {
- const int cc = ggml_cuda_info().devices[ctx.device].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+ const int cc = ggml_cuda_info().devices[ctx.device].cc;
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }

  // debug helpers
@@ -1956,20 +2052,28 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

- if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+ //TODO update for generic tensor parallelism
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
+ if (!split && use_mul_mat_vec_f) {
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
- ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+ } else if (!split && use_mul_mat_f) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
  } else if (!split && use_mul_mat_vec_q) {
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
  } else if (!split && use_mul_mat_q) {
  ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
- !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // general KQ + KQV multi-batch without FlashAttention
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
- } else if (use_mul_mat_vec) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+ } else if (use_mul_mat_vec_f) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
  } else if (use_mul_mat_vec_q) {
  ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
  } else if (use_mul_mat_q) {
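The dispatch above tries the specialized kernels in a fixed order before falling back to batched cuBLAS or the split path. A boiled-down (and deliberately simplified) view of that precedence, with dummy predicates standing in for the real flags:

```cpp
enum class demo_kernel { mmvf, mmf, mmvq, mmq, batched_cublas, fallback };

// First matching predicate wins; the real code re-checks some flags on the
// split path, which is omitted here.
demo_kernel demo_pick(bool split, bool use_mmvf, bool use_mmf, bool use_mmvq,
                      bool use_mmq, bool batched_ok) {
    if (!split && use_mmvf)   return demo_kernel::mmvf;           // float vector kernel
    if (!split && use_mmf)    return demo_kernel::mmf;            // float matrix kernel
    if (!split && use_mmvq)   return demo_kernel::mmvq;           // quantized vector kernel
    if (!split && use_mmq)    return demo_kernel::mmq;            // quantized matrix kernel
    if (!split && batched_ok) return demo_kernel::batched_cublas; // multi-batch GEMM
    return demo_kernel::fallback;                                 // split/generic path
}
```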
@@ -1997,7 +2101,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
  if (ggml_is_quantized(src0->type)) {
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
  } else {
- ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
  }
  return;
  }
@@ -2006,6 +2110,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
  ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
  return;
  }
+
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
+ return;
+ }
  }

  cudaStream_t stream = ctx.stream();
@@ -2147,6 +2256,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_GET_ROWS_BACK:
  ggml_cuda_op_get_rows_back(ctx, dst);
  break;
+ case GGML_OP_SET_ROWS:
+ ggml_cuda_op_set_rows(ctx, dst);
+ break;
  case GGML_OP_DUP:
  ggml_cuda_dup(ctx, dst);
  break;
@@ -2160,6 +2272,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_ADD1: // TODO: more efficient implementation
  ggml_cuda_op_add(ctx, dst);
  break;
+ case GGML_OP_ADD_ID:
+ ggml_cuda_op_add_id(ctx, dst);
+ break;
  case GGML_OP_SUB:
  ggml_cuda_op_sub(ctx, dst);
  break;
@@ -2216,6 +2331,33 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_UNARY_OP_EXP:
  ggml_cuda_op_exp(ctx, dst);
  break;
+ case GGML_UNARY_OP_ELU:
+ ggml_cuda_op_elu(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(dst)) {
+ case GGML_GLU_OP_REGLU:
+ ggml_cuda_op_reglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU:
+ ggml_cuda_op_geglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_SWIGLU:
+ ggml_cuda_op_swiglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_cuda_op_swiglu_oai(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ ggml_cuda_op_geglu_erf(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ ggml_cuda_op_geglu_quick(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -2238,6 +2380,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_PAD:
  ggml_cuda_op_pad(ctx, dst);
  break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_cuda_op_pad_reflect_1d(ctx, dst);
+ break;
  case GGML_OP_ARANGE:
  ggml_cuda_op_arange(ctx, dst);
  break;
@@ -2307,9 +2452,24 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_ROPE_BACK:
  ggml_cuda_op_rope_back(ctx, dst);
  break;
+ case GGML_OP_ROLL:
+ ggml_cuda_op_roll(ctx, dst);
+ break;
  case GGML_OP_IM2COL:
  ggml_cuda_op_im2col(ctx, dst);
  break;
+ case GGML_OP_IM2COL_3D:
+ ggml_cuda_op_im2col_3d(ctx, dst);
+ break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
+ case GGML_OP_CONV_2D_DW:
+ ggml_cuda_op_conv2d_dw(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ ggml_cuda_conv_2d_transpose_p0(ctx, dst);
+ break;
  case GGML_OP_CONV_TRANSPOSE_1D:
  ggml_cuda_op_conv_transpose_1d(ctx,dst);
  break;
@@ -2322,6 +2482,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_SUM_ROWS:
  ggml_cuda_op_sum_rows(ctx, dst);
  break;
+ case GGML_OP_MEAN:
+ ggml_cuda_op_mean(ctx, dst);
+ break;
  case GGML_OP_SSM_CONV:
  ggml_cuda_op_ssm_conv(ctx, dst);
  break;
@@ -2352,6 +2515,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_OPT_STEP_ADAMW:
  ggml_cuda_opt_step_adamw(ctx, dst);
  break;
+ case GGML_OP_OPT_STEP_SGD:
+ ggml_cuda_opt_step_sgd(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -2470,6 +2636,14 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
  cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

+ const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
+ const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+ const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+ const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
+
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];

@@ -2491,9 +2665,20 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
  #endif
  }

- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ if (node->op == GGML_OP_ADD &&
+ node->src[1] && node->src[1]->ne[1] > 1 &&
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+ strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+ strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
+ // by means of matching node names. See
+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
  use_cuda_graph = false;
  #ifndef NDEBUG
  GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
@@ -2639,13 +2824,130 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
  }
  #endif

+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+ #ifndef NDEBUG
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+ GGML_ASSERT(unary_ops.size() == num_unary);
+ #endif
+
+ //TODO: remove special case once ggml_can_fuse can handle empty nodes
+ std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
+ std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+
+ if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+
+ if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
+ return false;
+ }
+
+ for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+ if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+ }
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
+ ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+ if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ return true;
+ }
+ }
+
+ if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+
+ if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
+ return false;
+ }
+
+ for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+ if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
+ }
+
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
+ ggml_tensor * weights = cgraph->nodes[node_idx+4];
+ if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ return true;
+ }
+ }
+
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+ return false;
+ }
+
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
+
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+ //rms norm only supports F32
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
+ mul->src[1]->type != GGML_TYPE_F32 ||
+ mul->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
+ //if rms norm is the B operand, then we don't handle broadcast
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+ return false;
+ }
+
+ //rms_norm kernel assumes contiguous rows
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+ return false;
+ }
+
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
+ return true;
+ }
+
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+ const ggml_tensor *scale = cgraph->nodes[node_idx];
+ const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
+ const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+ GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+ if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+ return false;
+ }
+
+ // Check for bias
+ if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+ return false;
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
  static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
  bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+ // flag used to determine whether it is an integrated_gpu
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;

  while (!graph_evaluated_or_captured) {
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
  // With the use of CUDA graphs, the execution will be performed by the graph launch.
  if (!use_cuda_graph || cuda_graph_update_required) {
+
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
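The SCALE → TANH → SCALE sequence recognized by ggml_cuda_can_fuse above is logit softcapping with zero bias on both scale ops. A scalar reference of what the fused kernel computes (s1 and s2 are the two scale factors):

```cpp
#include <cmath>

// y = s2 * tanh(s1 * x); valid for the fusion only when both biases are 0.
static float demo_softcap(float x, float s1, float s2) {
    return s2 * std::tanh(s1 * x);
}
```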

@@ -2653,16 +2955,87 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
  continue;
  }

+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+ if (!disable_fusion) {
+
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+ ggml_tensor * weights = cgraph->nodes[i+8];
+ ggml_tensor * selected_experts = cgraph->nodes[i+3];
+ ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+ i += 8;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
+ ggml_tensor * weights = cgraph->nodes[i+4];
+ ggml_tensor * selected_experts = cgraph->nodes[i+3];
+ ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+ i += 4;
+ continue;
+ }
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
+
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+ i += 2;
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+ continue;
+ }
+ }
  #ifndef NDEBUG
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
  for (int j = 0; j < GGML_MAX_SRC; j++) {
  if (node->src[j] != nullptr) {
  assert(node->src[j]->buffer);
  assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
- ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+ ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
  }
  }
- #endif
+ #else
+ GGML_UNUSED(integrated);
+ #endif // NDEBUG

  bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
  if (!ok) {
@@ -2681,6 +3054,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx

  CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
  graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+ ggml_cuda_lock_cv.notify_all();
+ }
  } else {
  graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
  }
@@ -2756,7 +3134,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
  }
  }

- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+ if (use_cuda_graph && cuda_graph_update_required) {
+ // Start CUDA graph capture
+ {
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+ }
+
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
  }

@@ -2815,6 +3199,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
  /* .graph_compute = */ ggml_backend_cuda_graph_compute,
  /* .event_record = */ ggml_backend_cuda_event_record,
  /* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .graph_optimize = */ NULL,
  };

  static ggml_guid_t ggml_backend_cuda_guid() {
@@ -2847,7 +3232,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
  return false;
  }

- #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
+ #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
  cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
  if (err != cudaSuccess) {
  // clear the error
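The host-registration path above pins an existing allocation and treats failure as non-fatal. A minimal sketch of the same call sequence, assuming only the CUDA runtime:

```cpp
#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

// Pin an existing host allocation so the GPU can DMA from it directly; on
// failure the buffer simply stays pageable.
static bool demo_pin_host_buffer(void * buffer, size_t size) {
    cudaError_t err = cudaHostRegister(buffer, size,
                                       cudaHostRegisterPortable | cudaHostRegisterReadOnly);
    if (err != cudaSuccess) {
        cudaGetLastError(); // clear the sticky error state, as above
        std::fprintf(stderr, "host registration failed: %s\n", cudaGetErrorString(err));
        return false;
    }
    return true;
}
```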
@@ -2884,6 +3269,7 @@ struct ggml_backend_cuda_device_context {
  int device;
  std::string name;
  std::string description;
+ std::string pci_bus_id;
  };

  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -2908,9 +3294,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
  }

  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+
  props->name = ggml_backend_cuda_device_get_name(dev);
  props->description = ggml_backend_cuda_device_get_description(dev);
  props->type = ggml_backend_cuda_device_get_type(dev);
+ props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
  ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);

  bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
@@ -2984,19 +3373,36 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
  case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_ELU:
  return ggml_is_contiguous(op->src[0]);
  default:
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
+ case GGML_GLU_OP_GEGLU_ERF:
+ case GGML_GLU_OP_GEGLU_QUICK:
+ return ggml_is_contiguous_1(op->src[0]);
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  {
  struct ggml_tensor * a = op->src[0];
  struct ggml_tensor * b = op->src[1];
- // for small weight matrices the active device can end up without any rows, don't use row split in those cases
- // this avoids some edge cases (and the performance would not be good anyways)
  if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
+ if (a->ne[2] > 1 || a->ne[3] > 1) {
+ return false;
+ }
+ // for small weight matrices the active device can end up without any rows, don't use row split in those cases
+ // this avoids some edge cases (and the performance would not be good anyways)
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
  int64_t row_low;
  int64_t row_high;
@@ -3009,9 +3415,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  return false;
  }
  #ifdef GGML_USE_MUSA
- if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
- !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
- return false;
+ const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+ if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+ if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+ a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+ return false;
+ }
+ if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+ a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+ return false;
+ }
  }
  #endif // GGML_USE_MUSA
  switch (a->type) {
@@ -3022,6 +3435,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_TYPE_Q5_0:
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -3038,11 +3452,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_BF16:
- #ifdef GGML_USE_MUSA
- if (a->type == GGML_TYPE_Q3_K) {
- return false;
- }
- #endif // GGML_USE_MUSA
  return true;
  default:
  return false;
@@ -3055,6 +3464,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  switch (op->src[0]->type) {
  case GGML_TYPE_F16:
  case GGML_TYPE_F32:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_I32:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q5_0:
@@ -3069,17 +3480,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  {
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
  } break;
+ case GGML_OP_SET_ROWS:
+ {
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+ op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
+ op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
+ } break;
  case GGML_OP_CPY:
  {
  ggml_type src0_type = op->src[0]->type;
  ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+ (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+ ) {
  return true;
  }
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3115,10 +3530,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
  return true;
  }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
  return true;
  }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
  return true;
  }
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
@@ -3173,6 +3588,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
  case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
  case GGML_OP_ADD1:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
@@ -3184,12 +3600,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
  case GGML_OP_LOG:
- case GGML_OP_SSM_SCAN:
- case GGML_OP_SSM_CONV:
  return true;
+ case GGML_OP_SSM_SCAN: {
+ if (op->src[3]->ne[0] == 1) {
+ // Mamba2
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
+ } else {
+ // Mamba
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+ }
+ }
+ case GGML_OP_SSM_CONV: {
+ // assumes d_inner % threads == 0
+ return op->src[0]->ne[1] % 128 == 0;
+ }
  case GGML_OP_CONT:
- return op->src[0]->type != GGML_TYPE_BF16;
+ return true;
  case GGML_OP_DIAG_MASK_INF:
+ return true;
  case GGML_OP_SOFT_MAX:
  return true;
  case GGML_OP_SOFT_MAX_BACK: {
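The SSM_SCAN checks above split on tensor shape to tell Mamba2 from classic Mamba and then apply each kernel's constraints. The same conditions restated with named dimensions (the variable names are descriptive, not ggml's):

```cpp
#include <cstdint>

static bool demo_ssm_scan_supported(int64_t d_state, int64_t d_head, int64_t n_head,
                                    int64_t n_group, bool is_mamba2) {
    if (is_mamba2) {
        // kernel handles only d_state 128/256 with head dims divisible by 16
        return (d_state == 128 || d_state == 256) && d_head % 16 == 0;
    }
    // classic Mamba: fixed small state, scalar heads, 128-aligned head count
    return d_state == 16 && d_head == 1 && n_head % 128 == 0 && n_group == 1;
}
```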
@@ -3197,22 +3627,34 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
  return max_bias == 0.0f;
  }
+ case GGML_OP_ROLL:
+ if(op->src[0]->type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
  case GGML_OP_ROPE:
  case GGML_OP_ROPE_BACK: {
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
  }
  case GGML_OP_IM2COL:
+ case GGML_OP_IM2COL_3D:
+ case GGML_OP_CONV_2D:
+ case GGML_OP_CONV_2D_DW:
+ case GGML_OP_CONV_TRANSPOSE_2D:
  case GGML_OP_POOL_2D:
  case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_ARGSORT:
  case GGML_OP_ACC:
  return true;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  case GGML_OP_GROUP_NORM:
+ case GGML_OP_PAD:
  return ggml_is_contiguous(op->src[0]);
  case GGML_OP_UPSCALE:
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
- case GGML_OP_PAD:
+ case GGML_OP_PAD_REFLECT_1D:
  case GGML_OP_ARANGE:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_LEAKY_RELU:
@@ -3220,42 +3662,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_GATED_LINEAR_ATTN:
  case GGML_OP_RWKV_WKV7:
  return true;
- case GGML_OP_FLASH_ATTN_EXT: {
- #ifndef FLASH_ATTN_AVAILABLE
- return false;
- #endif // FLASH_ATTN_AVAILABLE
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
- const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
- if (!new_mma_available(cc)) {
- return false;
- }
- const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
- return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
- }
- if (op->src[0]->ne[0] == 192) {
- return false;
- }
- if (op->src[0]->ne[3] != 1) {
- return false;
- }
- if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
- return false;
- }
- if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
- return true;
- }
- if (op->src[0]->ne[0] == 128) {
- return true;
- }
- if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
- return true;
- }
- return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
- op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
- }
+ case GGML_OP_FLASH_ATTN_EXT:
+ return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
  case GGML_OP_CROSS_ENTROPY_LOSS:
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
  case GGML_OP_OPT_STEP_ADAMW:
+ case GGML_OP_OPT_STEP_SGD:
  return true;
  default:
  return false;
@@ -3263,7 +3675,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  }

  static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+ const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
+ return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
  }

  static int64_t get_op_batch_size(const ggml_tensor * op) {
@@ -3385,10 +3799,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
  features.push_back({ "NO_PEER_COPY", "1" });
  #endif

- #ifdef GGML_CUDA_F16
- features.push_back({ "F16", "1" });
- #endif
-
  #ifdef GGML_CUDA_USE_GRAPHS
  features.push_back({ "USE_GRAPHS", "1" });
  #endif
@@ -3459,6 +3869,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
  dev_ctx->description = prop.name;

+ char pci_bus_id[16] = {};
+ snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
+ dev_ctx->pci_bus_id = pci_bus_id;
+
  ggml_backend_dev_t dev = new ggml_backend_device {
  /* .iface = */ ggml_backend_cuda_device_interface,
  /* .reg = */ &reg,
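The new pci_bus_id field above is formatted from cudaDeviceProp as domain:bus:device.function, with the function fixed at 0. A minimal sketch of the same formatting:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

static void demo_print_pci_bus_id(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    char pci_bus_id[16] = {};
    std::snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0",
                  prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
    std::printf("device %d: %s\n", device, pci_bus_id);
}
```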
@@ -3493,10 +3907,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
  }

  ggml_backend_t cuda_backend = new ggml_backend {
- /* .guid = */ ggml_backend_cuda_guid(),
- /* .interface = */ ggml_backend_cuda_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
- /* .context = */ ctx,
+ /* .guid = */ ggml_backend_cuda_guid(),
+ /* .iface = */ ggml_backend_cuda_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+ /* .context = */ ctx,
  };

  return cuda_backend;