whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -4,6 +4,7 @@
4
4
 
5
5
  #include "ggml-cuda/common.cuh"
6
6
  #include "ggml-cuda/acc.cuh"
7
+ #include "ggml-cuda/add-id.cuh"
7
8
  #include "ggml-cuda/arange.cuh"
8
9
  #include "ggml-cuda/argmax.cuh"
9
10
  #include "ggml-cuda/argsort.cuh"
@@ -11,6 +12,7 @@
11
12
  #include "ggml-cuda/clamp.cuh"
12
13
  #include "ggml-cuda/concat.cuh"
13
14
  #include "ggml-cuda/conv-transpose-1d.cuh"
15
+ #include "ggml-cuda/conv2d.cuh"
14
16
  #include "ggml-cuda/conv2d-dw.cuh"
15
17
  #include "ggml-cuda/conv2d-transpose.cuh"
16
18
  #include "ggml-cuda/convert.cuh"
@@ -21,17 +23,21 @@
21
23
  #include "ggml-cuda/fattn.cuh"
22
24
  #include "ggml-cuda/getrows.cuh"
23
25
  #include "ggml-cuda/im2col.cuh"
26
+ #include "ggml-cuda/mmf.cuh"
24
27
  #include "ggml-cuda/mmq.cuh"
25
- #include "ggml-cuda/mmv.cuh"
28
+ #include "ggml-cuda/mmvf.cuh"
26
29
  #include "ggml-cuda/mmvq.cuh"
27
30
  #include "ggml-cuda/norm.cuh"
28
31
  #include "ggml-cuda/opt-step-adamw.cuh"
32
+ #include "ggml-cuda/opt-step-sgd.cuh"
29
33
  #include "ggml-cuda/out-prod.cuh"
30
34
  #include "ggml-cuda/pad.cuh"
31
35
  #include "ggml-cuda/pool2d.cuh"
32
36
  #include "ggml-cuda/quantize.cuh"
33
37
  #include "ggml-cuda/rope.cuh"
38
+ #include "ggml-cuda/roll.cuh"
34
39
  #include "ggml-cuda/scale.cuh"
40
+ #include "ggml-cuda/softcap.cuh"
35
41
  #include "ggml-cuda/softmax.cuh"
36
42
  #include "ggml-cuda/ssm-conv.cuh"
37
43
  #include "ggml-cuda/ssm-scan.cuh"
@@ -39,10 +45,13 @@
39
45
  #include "ggml-cuda/sumrows.cuh"
40
46
  #include "ggml-cuda/mean.cuh"
41
47
  #include "ggml-cuda/tsembd.cuh"
48
+ #include "ggml-cuda/topk-moe.cuh"
42
49
  #include "ggml-cuda/unary.cuh"
43
50
  #include "ggml-cuda/upscale.cuh"
44
51
  #include "ggml-cuda/wkv.cuh"
45
52
  #include "ggml-cuda/gla.cuh"
53
+ #include "ggml-cuda/set-rows.cuh"
54
+ #include "ggml-cuda/pad_reflect_1d.cuh"
46
55
  #include "ggml.h"
47
56
 
48
57
  #include <algorithm>
@@ -54,6 +63,7 @@
54
63
  #include <cstddef>
55
64
  #include <cstdint>
56
65
  #include <float.h>
66
+ #include <initializer_list>
57
67
  #include <limits>
58
68
  #include <map>
59
69
  #include <memory>
@@ -124,7 +134,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
124
134
  return err;
125
135
  }
126
136
 
127
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
137
+ #if defined(GGML_USE_HIP)
128
138
  static int ggml_cuda_parse_id(char devName[]) {
129
139
  // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
130
140
  // these values are not stable so this is susceptible to breakage
@@ -171,33 +181,9 @@ static int ggml_cuda_parse_id(char devName[]) {
171
181
  archNum += archMinor;
172
182
  return archNum;
173
183
  }
174
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
184
+ #endif // defined(GGML_USE_HIP)
175
185
 
176
186
  static ggml_cuda_device_info ggml_cuda_init() {
177
- #ifdef __HIP_PLATFORM_AMD__
178
- // Workaround for a rocBLAS bug when using multiple graphics cards:
179
- // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
180
- {
181
- int major_version = 0;
182
- size_t version_length = 0;
183
- if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
184
- std::vector<char> version(version_length+1, '\0');
185
- if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
186
- version.resize(::strlen(version.data()));
187
- int parsed_value = 0;
188
- if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
189
- major_version = parsed_value;
190
- }
191
- }
192
- }
193
- if (major_version < 4) {
194
- GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
195
- rocblas_initialize();
196
- CUDA_CHECK(cudaDeviceSynchronize());
197
- }
198
- }
199
- #endif
200
-
201
187
  ggml_cuda_device_info info = {};
202
188
 
203
189
  cudaError_t err = cudaGetDeviceCount(&info.device_count);
@@ -220,6 +206,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
220
206
  GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
221
207
  #endif // GGML_CUDA_FORCE_CUBLAS
222
208
  GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
209
+
210
+ std::vector<std::pair<int, std::string>> turing_devices_without_mma;
223
211
  for (int id = 0; id < info.device_count; ++id) {
224
212
  int device_vmm = 0;
225
213
 
@@ -247,7 +235,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
247
235
  info.devices[id].nsm = prop.multiProcessorCount;
248
236
  info.devices[id].smpb = prop.sharedMemPerBlock;
249
237
  info.devices[id].warp_size = prop.warpSize;
250
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
238
+ #if defined(GGML_USE_HIP)
251
239
  info.devices[id].smpbo = prop.sharedMemPerBlock;
252
240
 
253
241
  info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
@@ -277,7 +265,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
277
265
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
278
266
  GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
279
267
  id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
280
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
268
+ std::string device_name(prop.name);
269
+ if (device_name == "NVIDIA GeForce MX450") {
270
+ turing_devices_without_mma.push_back({ id, device_name });
271
+ } else if (device_name == "NVIDIA GeForce MX550") {
272
+ turing_devices_without_mma.push_back({ id, device_name });
273
+ } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
274
+ turing_devices_without_mma.push_back({ id, device_name });
275
+ }
276
+ #endif // defined(GGML_USE_HIP)
277
+ }
278
+
279
+ if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
280
+ GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
281
+ for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
282
+ GGML_LOG_INFO(
283
+ " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
284
+ }
285
+ GGML_LOG_INFO(
286
+ "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
281
287
  }
282
288
 
283
289
  for (int id = 0; id < info.device_count; ++id) {
@@ -1345,9 +1351,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1345
1351
  &beta, dst_dd_i, ldc));
1346
1352
  }
1347
1353
 
1348
- GGML_UNUSED(dst);
1349
- GGML_UNUSED(src1_ddq_i);
1350
- GGML_UNUSED(src1_padded_row_size);
1354
+ GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
1351
1355
  }
1352
1356
 
1353
1357
  static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
@@ -1848,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1848
1852
  ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
1849
1853
  ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
1850
1854
 
1855
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
1856
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
1857
+
1851
1858
  // Handle src0
1852
1859
  src0_ptr = (const cuda_t *) src0->data;
1853
1860
 
@@ -1866,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1866
1873
  s11 = ne10;
1867
1874
  s12 = ne11*s11;
1868
1875
  s13 = ne12*s12;
1876
+
1877
+ is_src1_cont_2 = true;
1869
1878
  }
1870
1879
 
1871
1880
  // Setup destination buffer
@@ -1914,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1914
1923
  const int64_t r2 = ne12/ne02;
1915
1924
  const int64_t r3 = ne13/ne03;
1916
1925
 
1917
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1926
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1927
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1928
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1929
+ const int64_t smb = ne12 == 1 ? s13 : s12;
1930
+
1918
1931
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
1919
1932
  // use cublasGemmStridedBatchedEx
1920
1933
  CUBLAS_CHECK(
1921
1934
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
1922
1935
  ne01, ne11, ne10,
1923
- alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1924
- src1_ptr, cu_data_type_b, s11, s12, // strideB
1925
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1936
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1937
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
1938
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1926
1939
  ne12*ne13,
1927
1940
  cu_compute_type,
1928
1941
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1994,7 +2007,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
  && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
- bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+ bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+ bool use_mul_mat_f = !ggml_is_quantized(src0->type)
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
  bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2014,14 +2029,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  }
 
  const int cc = ggml_cuda_info().devices[id].cc;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }
  } else {
  const int cc = ggml_cuda_info().devices[ctx.device].cc;
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
  use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
  any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }
 
@@ -2034,15 +2053,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
  //TODO update for generic tensor parallelism
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
  bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
  bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
  bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
 
- if (!split && use_mul_mat_vec) {
+ if (!split && use_mul_mat_vec_f) {
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
- ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+ } else if (!split && use_mul_mat_f) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
  } else if (!split && use_mul_mat_vec_q) {
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
  } else if (!split && use_mul_mat_q) {
@@ -2051,8 +2072,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // general KQ + KQV multi-batch without FlashAttention
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
- } else if (use_mul_mat_vec) {
- ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+ } else if (use_mul_mat_vec_f) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
  } else if (use_mul_mat_vec_q) {
  ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
  } else if (use_mul_mat_q) {
@@ -2080,7 +2101,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
  if (ggml_is_quantized(src0->type)) {
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
  } else {
- ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
  }
  return;
  }
@@ -2089,6 +2110,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
  ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
  return;
  }
+
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
+ return;
+ }
  }
 
  cudaStream_t stream = ctx.stream();
@@ -2230,6 +2256,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_GET_ROWS_BACK:
  ggml_cuda_op_get_rows_back(ctx, dst);
  break;
+ case GGML_OP_SET_ROWS:
+ ggml_cuda_op_set_rows(ctx, dst);
+ break;
  case GGML_OP_DUP:
  ggml_cuda_dup(ctx, dst);
  break;
@@ -2243,6 +2272,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_ADD1: // TODO: more efficient implementation
  ggml_cuda_op_add(ctx, dst);
  break;
+ case GGML_OP_ADD_ID:
+ ggml_cuda_op_add_id(ctx, dst);
+ break;
  case GGML_OP_SUB:
  ggml_cuda_op_sub(ctx, dst);
  break;
@@ -2299,6 +2331,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_UNARY_OP_EXP:
  ggml_cuda_op_exp(ctx, dst);
  break;
+ case GGML_UNARY_OP_ELU:
+ ggml_cuda_op_elu(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -2314,6 +2349,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_GLU_OP_SWIGLU:
  ggml_cuda_op_swiglu(ctx, dst);
  break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_cuda_op_swiglu_oai(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ ggml_cuda_op_geglu_erf(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ ggml_cuda_op_geglu_quick(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -2336,6 +2380,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_PAD:
  ggml_cuda_op_pad(ctx, dst);
  break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_cuda_op_pad_reflect_1d(ctx, dst);
+ break;
  case GGML_OP_ARANGE:
  ggml_cuda_op_arange(ctx, dst);
  break;
@@ -2405,9 +2452,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_ROPE_BACK:
  ggml_cuda_op_rope_back(ctx, dst);
  break;
+ case GGML_OP_ROLL:
+ ggml_cuda_op_roll(ctx, dst);
+ break;
  case GGML_OP_IM2COL:
  ggml_cuda_op_im2col(ctx, dst);
  break;
+ case GGML_OP_IM2COL_3D:
+ ggml_cuda_op_im2col_3d(ctx, dst);
+ break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
  case GGML_OP_CONV_2D_DW:
  ggml_cuda_op_conv2d_dw(ctx, dst);
  break;
@@ -2459,6 +2515,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_OPT_STEP_ADAMW:
  ggml_cuda_opt_step_adamw(ctx, dst);
  break;
+ case GGML_OP_OPT_STEP_SGD:
+ ggml_cuda_opt_step_sgd(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -2577,6 +2636,14 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
  // Loop over nodes in GGML graph to obtain info needed for CUDA graph
  cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
+ const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
+ const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+ const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+ const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
+
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
 
@@ -2598,9 +2665,20 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
  #endif
  }
 
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ if (node->op == GGML_OP_ADD &&
+ node->src[1] && node->src[1]->ne[1] > 1 &&
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+ strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+ strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
+ // by means of matching node names. See
+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
  use_cuda_graph = false;
  #ifndef NDEBUG
  GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
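The exclusions above are plain prefix comparisons on node names. A sketch of the equivalent check that the strncmp calls amount to (the helper name is hypothetical, not part of the patch):

```cpp
#include <cstring>
#include <string>

// Returns true when the node name starts with the given prefix,
// mirroring the strncmp(...) != 0 exclusions in the condition above.
static bool node_name_has_prefix(const char * name, const std::string & prefix) {
    return std::strncmp(name, prefix.c_str(), prefix.size()) == 0;
}
```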
@@ -2746,6 +2824,120 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
  }
  #endif
 
+ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+ #ifndef NDEBUG
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+ GGML_ASSERT(unary_ops.size() == num_unary);
+ #endif
+
+ //TODO: remove special case once ggml_can_fuse can handle empty nodes
+ std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
+ std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+
+ if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+
+ if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
+ return false;
+ }
+
+ for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+ if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+ }
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
+ ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+ if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ return true;
+ }
+ }
+
+ if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+
+ if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
+ return false;
+ }
+
+ for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+ if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
+ }
+
+ ggml_tensor * softmax = cgraph->nodes[node_idx];
+ ggml_tensor * weights = cgraph->nodes[node_idx+4];
+ if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ return true;
+ }
+ }
+
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+ return false;
+ }
+
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
+
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+ //rms norm only supports F32
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
+ mul->src[1]->type != GGML_TYPE_F32 ||
+ mul->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
+ //if rms norm is the B operand, then we don't handle broadcast
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+ return false;
+ }
+
+ //rms_norm kernel assumes contigous rows
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+ return false;
+ }
+
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
+ return true;
+ }
+
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+ const ggml_tensor *scale = cgraph->nodes[node_idx];
+ const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
+ const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+ GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+ if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+ return false;
+ }
+
+ // Check for bias
+ if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+ return false;
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
  static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
  bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
  // flag used to determine whether it is an integrated_gpu
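Per row, the RMS_NORM -> MUL (-> ADD) chain that ggml_cuda_can_fuse accepts computes y = (x / rms(x)) * w (+ b). A scalar reference sketch of that math, for illustration only and not the fused CUDA kernel from the patch:

```cpp
#include <cmath>
#include <cstddef>

// Scalar reference for the fused RMS_NORM + MUL (+ ADD) pattern over one row.
static void rms_norm_mul_add_row(const float * x, const float * w, const float * b,
                                 float * y, size_t n, float eps) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float inv_rms = 1.0f / std::sqrt(sum_sq / (float) n + eps);
    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] * inv_rms * w[i] + (b ? b[i] : 0.0f); // b == nullptr for the 2-op variant
    }
}
```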
@@ -2755,6 +2947,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
  // With the use of CUDA graphs, the execution will be performed by the graph launch.
  if (!use_cuda_graph || cuda_graph_update_required) {
+
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
 
@@ -2762,6 +2955,75 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
  continue;
  }
 
+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+ if (!disable_fusion) {
+
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+ ggml_tensor * weights = cgraph->nodes[i+8];
+ ggml_tensor * selected_experts = cgraph->nodes[i+3];
+ ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+ i += 8;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
+ ggml_tensor * weights = cgraph->nodes[i+4];
+ ggml_tensor * selected_experts = cgraph->nodes[i+3];
+ ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+ i += 4;
+ continue;
+ }
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
+
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+ i += 2;
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+ continue;
+ }
+ }
  #ifndef NDEBUG
  assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
  for (int j = 0; j < GGML_MAX_SRC; j++) {
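The whole fusion path above can be disabled at runtime through the GGML_CUDA_DISABLE_FUSION environment variable checked at the top of the block. The SCALE -> TANH -> SCALE chain that is dispatched to ggml_cuda_op_softcap reduces, element-wise, to y = s_out * tanh(s_in * x), with both scale biases required to be zero by the fusion check. A scalar sketch for reference, not the kernel itself:

```cpp
#include <cmath>

// Element-wise form of the fused "softcap" pattern: scale, tanh, scale.
static float softcap(float x, float scale_in, float scale_out) {
    return scale_out * std::tanh(scale_in * x);
}
```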
@@ -2937,6 +3199,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
  /* .graph_compute = */ ggml_backend_cuda_graph_compute,
  /* .event_record = */ ggml_backend_cuda_event_record,
  /* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .graph_optimize = */ NULL,
  };
 
  static ggml_guid_t ggml_backend_cuda_guid() {
@@ -2969,7 +3232,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
  return false;
  }
 
- #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
+ #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
  cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
  if (err != cudaSuccess) {
  // clear the error
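The hunk above only widens the preprocessor guard to HIP builds; the registration path itself is unchanged. A self-contained sketch of that path, mirroring the surrounding error handling (clear the sticky error and report failure):

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Sketch of the cudaHostRegister path guarded above; returns false on failure.
static bool register_host_buffer(void * buffer, size_t size) {
    cudaError_t err = cudaHostRegister(buffer, size,
                                       cudaHostRegisterPortable | cudaHostRegisterReadOnly);
    if (err != cudaSuccess) {
        (void) cudaGetLastError(); // clear the error so later calls are unaffected
        return false;
    }
    return true;
}
```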
@@ -3006,6 +3269,7 @@ struct ggml_backend_cuda_device_context {
  int device;
  std::string name;
  std::string description;
+ std::string pci_bus_id;
  };
 
  static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3030,9 +3294,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
  }
 
  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+
  props->name = ggml_backend_cuda_device_get_name(dev);
  props->description = ggml_backend_cuda_device_get_description(dev);
  props->type = ggml_backend_cuda_device_get_type(dev);
+ props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
  ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
 
  bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
@@ -3106,6 +3373,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
  case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_ELU:
  return ggml_is_contiguous(op->src[0]);
  default:
  return false;
@@ -3116,6 +3384,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_GLU_OP_REGLU:
  case GGML_GLU_OP_GEGLU:
  case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
+ case GGML_GLU_OP_GEGLU_ERF:
+ case GGML_GLU_OP_GEGLU_QUICK:
  return ggml_is_contiguous_1(op->src[0]);
  default:
  return false;
@@ -3164,6 +3435,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_TYPE_Q5_0:
  case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
  case GGML_TYPE_Q2_K:
  case GGML_TYPE_Q3_K:
  case GGML_TYPE_Q4_K:
@@ -3192,6 +3464,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  switch (op->src[0]->type) {
  case GGML_TYPE_F16:
  case GGML_TYPE_F32:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_I32:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q5_0:
@@ -3206,17 +3480,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  {
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
  } break;
+ case GGML_OP_SET_ROWS:
+ {
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+ op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
+ op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
+ } break;
  case GGML_OP_CPY:
  {
  ggml_type src0_type = op->src[0]->type;
  ggml_type src1_type = op->src[1]->type;
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
- return true;
- }
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+ (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+ ) {
  return true;
  }
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3252,10 +3530,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
  return true;
  }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
  return true;
  }
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
  return true;
  }
  if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
@@ -3310,6 +3588,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
  case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
  case GGML_OP_ADD1:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
@@ -3321,12 +3600,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
  case GGML_OP_LOG:
- case GGML_OP_SSM_SCAN:
- case GGML_OP_SSM_CONV:
  return true;
+ case GGML_OP_SSM_SCAN: {
+ if (op->src[3]->ne[0] == 1) {
+ // Mamba2
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
+ } else {
+ // Mamba
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+ }
+ }
+ case GGML_OP_SSM_CONV: {
+ // assumes d_inner % threads == 0
+ return op->src[0]->ne[1] % 128 == 0;
+ }
  case GGML_OP_CONT:
- return op->src[0]->type != GGML_TYPE_BF16;
+ return true;
  case GGML_OP_DIAG_MASK_INF:
+ return true;
  case GGML_OP_SOFT_MAX:
  return true;
  case GGML_OP_SOFT_MAX_BACK: {
@@ -3334,25 +3627,34 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
  return max_bias == 0.0f;
  }
+ case GGML_OP_ROLL:
+ if(op->src[0]->type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
  case GGML_OP_ROPE:
  case GGML_OP_ROPE_BACK: {
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
  }
  case GGML_OP_IM2COL:
+ case GGML_OP_IM2COL_3D:
+ case GGML_OP_CONV_2D:
  case GGML_OP_CONV_2D_DW:
  case GGML_OP_CONV_TRANSPOSE_2D:
  case GGML_OP_POOL_2D:
  case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_MEAN:
- case GGML_OP_ARGSORT:
  case GGML_OP_ACC:
  return true;
+ case GGML_OP_ARGSORT:
+ // TODO: Support arbitrary column width
+ return op->src[0]->ne[0] <= 1024;
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  case GGML_OP_GROUP_NORM:
+ case GGML_OP_PAD:
  return ggml_is_contiguous(op->src[0]);
  case GGML_OP_UPSCALE:
- return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
- case GGML_OP_PAD:
+ case GGML_OP_PAD_REFLECT_1D:
  case GGML_OP_ARANGE:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_LEAKY_RELU:
@@ -3360,42 +3662,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_GATED_LINEAR_ATTN:
  case GGML_OP_RWKV_WKV7:
  return true;
- case GGML_OP_FLASH_ATTN_EXT: {
- #ifndef FLASH_ATTN_AVAILABLE
- return false;
- #endif // FLASH_ATTN_AVAILABLE
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
- const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
- if (!new_mma_available(cc)) {
- return false;
- }
- const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
- return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
- }
- if (op->src[0]->ne[0] == 192) {
- return false;
- }
- if (op->src[0]->ne[3] != 1) {
- return false;
- }
- if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
- return false;
- }
- if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
- return true;
- }
- if (op->src[0]->ne[0] == 128) {
- return true;
- }
- if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
- return true;
- }
- return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
- op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
- }
+ case GGML_OP_FLASH_ATTN_EXT:
+ return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
  case GGML_OP_CROSS_ENTROPY_LOSS:
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
  case GGML_OP_OPT_STEP_ADAMW:
+ case GGML_OP_OPT_STEP_SGD:
  return true;
  default:
  return false;
@@ -3527,10 +3799,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
  features.push_back({ "NO_PEER_COPY", "1" });
  #endif
 
- #ifdef GGML_CUDA_F16
- features.push_back({ "F16", "1" });
- #endif
-
  #ifdef GGML_CUDA_USE_GRAPHS
  features.push_back({ "USE_GRAPHS", "1" });
  #endif
@@ -3601,6 +3869,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
  dev_ctx->description = prop.name;
 
+ char pci_bus_id[16] = {};
+ snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
+ dev_ctx->pci_bus_id = pci_bus_id;
+
  ggml_backend_dev_t dev = new ggml_backend_device {
  /* .iface = */ ggml_backend_cuda_device_interface,
  /* .reg = */ &reg,
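The new pci_bus_id field is formatted as domain:bus:device.function (for example "0000:65:00.0") from the cudaDeviceProp fields and is later exposed through the props getter as device_id. A small standalone sketch of the same formatting, for illustration only:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Formats the PCI bus id the same way as the snprintf in the hunk above.
static void print_pci_bus_id(int device) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return;
    }
    char pci_bus_id[16] = {};
    std::snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0",
                  prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
    std::printf("device %d: %s\n", device, pci_bus_id);
}
```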
@@ -3635,10 +3907,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
  }
 
  ggml_backend_t cuda_backend = new ggml_backend {
- /* .guid = */ ggml_backend_cuda_guid(),
- /* .interface = */ ggml_backend_cuda_interface,
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
- /* .context = */ ctx,
+ /* .guid = */ ggml_backend_cuda_guid(),
+ /* .iface = */ ggml_backend_cuda_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+ /* .context = */ ctx,
  };
 
  return cuda_backend;