whispercpp 1.3.2 → 1.3.3

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -11,6 +11,8 @@
  #include "ggml-cuda/clamp.cuh"
  #include "ggml-cuda/concat.cuh"
  #include "ggml-cuda/conv-transpose-1d.cuh"
+ #include "ggml-cuda/conv2d-dw.cuh"
+ #include "ggml-cuda/conv2d-transpose.cuh"
  #include "ggml-cuda/convert.cuh"
  #include "ggml-cuda/count-equal.cuh"
  #include "ggml-cuda/cpy.cuh"
@@ -35,6 +37,7 @@
  #include "ggml-cuda/ssm-scan.cuh"
  #include "ggml-cuda/sum.cuh"
  #include "ggml-cuda/sumrows.cuh"
+ #include "ggml-cuda/mean.cuh"
  #include "ggml-cuda/tsembd.cuh"
  #include "ggml-cuda/unary.cuh"
  #include "ggml-cuda/upscale.cuh"
@@ -47,6 +50,7 @@
  #include <atomic>
  #include <charconv>
  #include <cinttypes>
+ #include <condition_variable>
  #include <cstddef>
  #include <cstdint>
  #include <float.h>
@@ -54,9 +58,8 @@
  #include <map>
  #include <memory>
  #include <mutex>
- #include <stdint.h>
- #include <stdio.h>
  #include <stdarg.h>
+ #include <stdio.h>
  #include <stdlib.h>
  #include <string>
  #include <vector>
@@ -97,8 +100,7 @@ int ggml_cuda_get_device() {
  static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
  ggml_cuda_set_device(device);
  cudaError_t err;
- if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
- {
+ if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
  err = cudaMallocManaged(ptr, size);
  #if defined(GGML_USE_HIP)
  if (err == hipSuccess) {
@@ -116,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
  err = cudaMalloc(ptr, size);
  }
  #endif // defined(GGML_USE_HIP)
- }
- else
- {
+ } else {
  err = cudaMalloc(ptr, size);
  }
  return err;
@@ -243,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() {

  info.default_tensor_split[id] = total_vram;
  total_vram += prop.totalGlobalMem;
-
- info.devices[id].nsm = prop.multiProcessorCount;
- info.devices[id].smpb = prop.sharedMemPerBlock;
- info.devices[id].warp_size = prop.warpSize;
+ info.devices[id].integrated = prop.integrated;
+ info.devices[id].nsm = prop.multiProcessorCount;
+ info.devices[id].smpb = prop.sharedMemPerBlock;
+ info.devices[id].warp_size = prop.warpSize;
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  info.devices[id].smpbo = prop.sharedMemPerBlock;

@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
  return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
  }

+ // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+ // this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+ static std::mutex ggml_cuda_lock;
+ static std::condition_variable ggml_cuda_lock_cv;
+ static std::atomic<int> ggml_cuda_lock_counter;
+
+ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+ std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+ ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+ if (copy_event != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(copy_event));
+ }
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+ for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+ if (streams[i][j] != nullptr) {
+ CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+ }
+ }
+ if (cublas_handles[i] != nullptr) {
+ CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+ }
+ }
+ }
+
+
  // cuda buffer

  struct ggml_backend_cuda_buffer_context {
@@ -615,9 +642,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

  ggml_cuda_set_device(ctx->device);
- CUDA_CHECK(cudaDeviceSynchronize());
- CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
- CUDA_CHECK(cudaDeviceSynchronize());
+ CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
  }

  static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
@@ -1065,6 +1091,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
  GGML_UNUSED(buft);
  }

+ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+ }
+
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  CUDA_CHECK(cudaFreeHost(buffer->context));
  }
@@ -1140,7 +1170,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
  static cudaError_t ggml_cuda_cpy_tensor_2d(
  void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

- GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
  const char * src_ptr = (const char *) src->data;
  char * dst_ptr = (char *) dst;

@@ -1198,9 +1227,12 @@ static void ggml_cuda_op_mul_mat_cublas(

  const int cc = ggml_cuda_info().devices[id].cc;

+ const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
  const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

- if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+ if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
  ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
  if (src1->type != GGML_TYPE_BF16) {
  const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1228,7 +1260,7 @@ static void ggml_cuda_op_mul_mat_cublas(

  const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
  to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
- } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+ } else if (fast_fp16_hardware_available(cc) && use_fp16) {
  // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
  ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
  if (src0->type != GGML_TYPE_F16) {
@@ -1423,8 +1455,6 @@ static void ggml_cuda_op_mul_mat(
  const int64_t nb2 = dst->nb[2];
  const int64_t nb3 = dst->nb[3];

- GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
  ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
  ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;

@@ -1719,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
  }

  static __global__ void k_compute_batched_ptrs(
- const half * src0_as_f16, const half * src1_as_f16, char * dst,
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
  const void ** ptrs_src, void ** ptrs_dst,
  int64_t ne12, int64_t ne13,
  int64_t ne23,
@@ -1742,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs(
  ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
  }

- static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ // Type traits for mapping ggml types to CUDA/cuBLAS types
+ template<ggml_type T>
+ struct batched_mul_mat_traits;
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_F32> {
+ using cuda_type = float;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+ };
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+ using cuda_type = nv_bfloat16;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+ };
+
+ template<>
+ struct batched_mul_mat_traits<GGML_TYPE_F16> {
+ using cuda_type = half;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+ static inline const half alpha = 1.0;
+ static inline const half beta = 0.0;
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
+ static inline const void* get_beta() { static const half val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+ };
+
+ template<ggml_type src0_type>
+ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ using traits = batched_mul_mat_traits<src0_type>;
+ using cuda_t = typename traits::cuda_type;
+
  GGML_ASSERT(!ggml_is_transposed(src0));
  GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
+ GGML_ASSERT(src0->type == src0_type);
+ GGML_ASSERT(ggml_is_contiguous(dst));

  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
  // As long as dst is contiguous this does not matter though.
- GGML_ASSERT(ggml_is_contiguous(dst));

  GGML_TENSOR_BINARY_OP_LOCALS

  const int64_t ne_dst = ggml_nelements(dst);
-
  cudaStream_t main_stream = ctx.stream();
-
  CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

- const half * src0_f16 = (const half *) src0->data;
  float * dst_ddf = (float *) dst->data;
-
- const half * src1_f16 = (const half *) src1->data;
  const size_t ts_src1 = ggml_type_size(src1->type);
  GGML_ASSERT(nb10 == ts_src1);
  int64_t s11 = nb11 / ts_src1;
  int64_t s12 = nb12 / ts_src1;
  int64_t s13 = nb13 / ts_src1;
- ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());

- // convert src1 to fp16
- if (src1->type != GGML_TYPE_F16) {
- const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
- const int64_t ne_src1 = ggml_nelements(src1);
- src1_f16_alloc.alloc(ne_src1);
- GGML_ASSERT(to_fp16_cuda != nullptr);
+ const cuda_t * src0_ptr = nullptr;
+ const cuda_t * src1_ptr = nullptr;

- to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+
+ // Handle src0
+ src0_ptr = (const cuda_t *) src0->data;
+
+ // Handle src1 - convert if necessary
+ if (src1->type == src0_type) {
+ src1_ptr = (const cuda_t *) src1->data;
+ } else {
+ // Convert src1 to target type using traits conversion functions
+ const int64_t ne_src1 = ggml_nelements(src1);
+ src1_alloc.alloc(ne_src1);

- src1_f16 = src1_f16_alloc.get();
+ const auto convert_func = traits::get_nc_converter(src1->type);
+ GGML_ASSERT(convert_func != nullptr);
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+ src1_ptr = src1_alloc.get();
  s11 = ne10;
  s12 = ne11*s11;
  s13 = ne12*s12;
  }

- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+ // Setup destination buffer
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
  char * dst_t;
-
- cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
- cudaDataType_t cu_data_type = CUDA_R_16F;
-
- // dst strides
  size_t nbd2 = dst->nb[2];
  size_t nbd3 = dst->nb[3];

- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
+ cublasComputeType_t cu_compute_type = traits::compute_type;
+ cudaDataType_t cu_data_type = traits::data_type;
+ cudaDataType_t cu_data_type_a = traits::data_type;
+ cudaDataType_t cu_data_type_b = traits::data_type;
+ const void * alpha = traits::get_alpha();
+ const void * beta = traits::get_beta();
  const float alpha_f32 = 1.0f;
- const float beta_f32 = 0.0f;
-
- const void * alpha = &alpha_f16;
- const void * beta = &beta_f16;
+ const float beta_f32 = 0.0f;

  if (dst->op_params[0] == GGML_PREC_DEFAULT) {
- dst_t = (char *) dst_f16.alloc(ne_dst);
-
- nbd2 /= sizeof(float) / sizeof(half);
- nbd3 /= sizeof(float) / sizeof(half);
+ if constexpr (src0_type == GGML_TYPE_F32) {
+ dst_t = (char *) dst_ddf; // Direct F32 output
+ } else {
+ dst_t = (char *) dst_temp.alloc(ne_dst);
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
+ }
  } else {
  dst_t = (char *) dst_ddf;
-
  cu_compute_type = CUBLAS_COMPUTE_32F;
- cu_data_type = CUDA_R_32F;
-
+ cu_data_type = CUDA_R_32F;
  alpha = &alpha_f32;
- beta = &beta_f32;
+ beta = &beta_f32;
  }

  int id = ggml_cuda_get_device();
@@ -1826,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
  cu_compute_type = CUBLAS_COMPUTE_32F;
  alpha = &alpha_f32;
- beta = &beta_f32;
+ beta = &beta_f32;
  }

  GGML_ASSERT(ne12 % ne02 == 0);
@@ -1836,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  const int64_t r2 = ne12/ne02;
  const int64_t r3 = ne13/ne03;

- #if 0
- // use cublasGemmEx
- {
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- int i03 = i13 / r3;
- int i02 = i12 / r2;
-
- CUBLAS_CHECK(
- cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half),
- src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11,
- beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- }
- }
- }
- #else
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  // use cublasGemmStridedBatchedEx
  CUBLAS_CHECK(
  cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
  ne01, ne11, ne10,
- alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
- src1_f16, CUDA_R_16F, s11, s12, // strideB
- beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
+ src1_ptr, cu_data_type_b, s11, s12, // strideB
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
  ne12*ne13,
  cu_compute_type,
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1875,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
  ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
  ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);

+ size_t src1_stride_size = sizeof(cuda_t);
+
  dim3 block_dims(ne13, ne12);
  k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
- src0_f16, src1_f16, dst_t,
+ src0_ptr, src1_ptr, dst_t,
  ptrs_src.get(), ptrs_dst.get(),
  ne12, ne13,
  ne23,
  nb02, nb03,
- src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
- src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
  nbd2, nbd3,
  r2, r3);
+
  CUDA_CHECK(cudaGetLastError());

  CUBLAS_CHECK(
  cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
  ne01, ne11, ne10,
- alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
- (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11,
- beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
  ne23,
  cu_compute_type,
  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  }
- #endif

- if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+ // Convert output back to F32 if needed
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+ }
+ }
+
+ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_BF16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_F16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+ break;
+ default:
+ GGML_ABORT("Unsupported type");
  }
  }

@@ -1916,16 +1995,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;

  bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
- && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
  bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
  && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
  bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

- bool any_gpus_with_slow_fp16 = false;
- bool any_gpus_without_fp16_mma = false;
+ bool any_gpus_with_slow_fp16 = false;

  if (split) {
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +2013,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  continue;
  }

- const int cc = ggml_cuda_info().devices[id].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+ const int cc = ggml_cuda_info().devices[id].cc;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }
  } else {
- const int cc = ggml_cuda_info().devices[ctx.device].cc;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
- any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
- any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+ const int cc = ggml_cuda_info().devices[ctx.device].cc;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
  }

  // debug helpers
@@ -1956,7 +2033,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
  //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

- if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+ //TODO update for generic tensor parallelism
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
+ if (!split && use_mul_mat_vec) {
  // the custom F16 vector kernel can be used over batched cuBLAS GEMM
  // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
  ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
@@ -1964,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
  } else if (!split && use_mul_mat_q) {
  ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
- } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
- !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
  // general KQ + KQV multi-batch without FlashAttention
  ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
  } else if (use_mul_mat_vec) {
@@ -2220,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(dst)) {
+ case GGML_GLU_OP_REGLU:
+ ggml_cuda_op_reglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU:
+ ggml_cuda_op_geglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_SWIGLU:
+ ggml_cuda_op_swiglu(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_NORM:
  ggml_cuda_op_norm(ctx, dst);
  break;
@@ -2310,6 +2408,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_IM2COL:
  ggml_cuda_op_im2col(ctx, dst);
  break;
+ case GGML_OP_CONV_2D_DW:
+ ggml_cuda_op_conv2d_dw(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ ggml_cuda_conv_2d_transpose_p0(ctx, dst);
+ break;
  case GGML_OP_CONV_TRANSPOSE_1D:
  ggml_cuda_op_conv_transpose_1d(ctx,dst);
  break;
@@ -2322,6 +2426,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
  case GGML_OP_SUM_ROWS:
  ggml_cuda_op_sum_rows(ctx, dst);
  break;
+ case GGML_OP_MEAN:
+ ggml_cuda_op_mean(ctx, dst);
+ break;
  case GGML_OP_SSM_CONV:
  ggml_cuda_op_ssm_conv(ctx, dst);
  break;
@@ -2641,6 +2748,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

  static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
  bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+ // flag used to determine whether it is an integrated_gpu
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;

  while (!graph_evaluated_or_captured) {
  // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2659,10 +2768,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
  if (node->src[j] != nullptr) {
  assert(node->src[j]->buffer);
  assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
- ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+ ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
  }
  }
- #endif
+ #else
+ GGML_UNUSED(integrated);
+ #endif // NDEBUG

  bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
  if (!ok) {
@@ -2681,6 +2792,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx

  CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
  graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+ ggml_cuda_lock_cv.notify_all();
+ }
  } else {
  graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
  }
@@ -2756,7 +2872,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
  }
  }

- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+ if (use_cuda_graph && cuda_graph_update_required) {
+ // Start CUDA graph capture
+ {
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+ }
+
  CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
  }

@@ -2989,14 +3111,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_SWIGLU:
+ return ggml_is_contiguous_1(op->src[0]);
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  {
  struct ggml_tensor * a = op->src[0];
  struct ggml_tensor * b = op->src[1];
- // for small weight matrices the active device can end up without any rows, don't use row split in those cases
- // this avoids some edge cases (and the performance would not be good anyways)
  if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
+ if (a->ne[2] > 1 || a->ne[3] > 1) {
+ return false;
+ }
+ // for small weight matrices the active device can end up without any rows, don't use row split in those cases
+ // this avoids some edge cases (and the performance would not be good anyways)
  ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
  int64_t row_low;
  int64_t row_high;
@@ -3009,9 +3144,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  return false;
  }
  #ifdef GGML_USE_MUSA
- if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
- !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
- return false;
+ const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+ if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+ if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+ a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+ return false;
+ }
+ if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+ a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+ return false;
+ }
  }
  #endif // GGML_USE_MUSA
  switch (a->type) {
@@ -3038,11 +3180,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_TYPE_IQ4_NL:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_BF16:
- #ifdef GGML_USE_MUSA
- if (a->type == GGML_TYPE_Q3_K) {
- return false;
- }
- #endif // GGML_USE_MUSA
  return true;
  default:
  return false;
@@ -3202,9 +3339,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
  }
  case GGML_OP_IM2COL:
+ case GGML_OP_CONV_2D_DW:
+ case GGML_OP_CONV_TRANSPOSE_2D:
  case GGML_OP_POOL_2D:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  case GGML_OP_ARGSORT:
  case GGML_OP_ACC:
  return true;
@@ -3263,7 +3403,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  }

  static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
- return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+ const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
+ return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
  }

  static int64_t get_op_batch_size(const ggml_tensor * op) {
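
Editor's note: the most cross-cutting change in the ggml-cuda.cu hunks above is the capture lock (ggml_cuda_lock, ggml_cuda_lock_cv, ggml_cuda_lock_counter): the diff's own comment explains that destroying a cuBLAS handle while another thread is capturing a CUDA graph can fail, so the new backend destructor waits until no capture is in flight. The following is a minimal, standalone C++ sketch of that synchronization pattern only; the names begin_capture/end_capture/destroy_context and the printed message are illustrative, and the real CUDA/cuBLAS calls are reduced to comments. It is not the backend's actual code.

// Sketch: an atomic counter tracks in-flight graph captures; destructors wait
// on a condition variable until the counter drops to zero.
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

static std::mutex              capture_lock;
static std::condition_variable capture_cv;
static std::atomic<int>        capture_count{0};

void begin_capture() {
    std::lock_guard<std::mutex> lock(capture_lock);
    capture_count.fetch_add(1, std::memory_order_relaxed);
    // ... cudaStreamBeginCapture(...) would go here ...
}

void end_capture() {
    // ... cudaStreamEndCapture(...) would go here ...
    std::lock_guard<std::mutex> lock(capture_lock);
    if (capture_count.fetch_sub(1, std::memory_order_relaxed) == 1) {
        capture_cv.notify_all(); // last outstanding capture finished
    }
}

void destroy_context() {
    // block until no capture is in flight, then release resources
    std::unique_lock<std::mutex> lock(capture_lock);
    capture_cv.wait(lock, [] { return capture_count.load(std::memory_order_relaxed) == 0; });
    std::puts("safe to destroy cuBLAS handles and streams here");
}

int main() {
    std::thread t([] { begin_capture(); end_capture(); });
    t.join();
    destroy_context();
}

Using a counter rather than a plain flag allows several captures to be in flight at once; end-of-capture only notifies waiters when its decrement observes the last outstanding capture, which mirrors the fetch_sub/notify_all pairing in the diff.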