whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,19 @@
1
+ #include "mean.cuh"
2
+ // Per-row reduction of dst->src[0] into dst via reduce_rows_f32 with norm=true (presumably each row sum divided by ncols, i.e. the mean; the kernel is defined elsewhere — confirm).
3
+ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
4
+ const ggml_tensor * src0 = dst->src[0]; // the single input tensor of this op
5
+ const float * src0_d = (const float *) src0->data;
6
+ float * dst_d = (float *) dst->data; // NOTE(review): dst's shape is not asserted here; assumes it holds at least nrows floats — confirm
7
+ cudaStream_t stream = ctx.stream(); // the kernel below is launched on this stream; nothing here synchronizes it
8
+ // Preconditions: F32 input, F32 output, contiguous src0 (the kernel receives no row stride, only ncols).
9
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
11
+ GGML_ASSERT(ggml_is_contiguous(src0));
12
+ // ncols = row length (ne[0]); nrows = number of rows as reported by ggml_nrows(src0).
13
+ const int64_t ncols = src0->ne[0];
14
+ const int64_t nrows = ggml_nrows(src0);
15
+ // Launch config: one block per row, WARP_SIZE threads per block, no dynamic shared memory.
16
+ const dim3 block_dims(WARP_SIZE, 1, 1);
17
+ const dim3 block_nums(nrows, 1, 1); // NOTE(review): nrows is narrowed to unsigned int by dim3, and a zero-row grid is an invalid launch; presumably empty tensors never reach this op — confirm
18
+ reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
19
+ }
@@ -0,0 +1,3 @@
1
+ #include "common.cuh"
2
+ // Row-wise mean of dst->src[0] written into dst; the implementation in mean.cu asserts F32 input/output and a contiguous input, and launches on ctx.stream().
3
+ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -2,25 +2,26 @@
2
2
  #include "common.cuh"
3
3
  #include "mmv.cuh"
4
4
 
5
- template <typename T, typename type_acc, int block_size>
5
+ template <typename T, typename type_acc, int ncols_dst, int block_size>
6
6
  static __global__ void mul_mat_vec(
7
7
  const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
8
- const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
9
- const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
10
- const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
11
- const int64_t row = blockIdx.x;
12
- const int64_t channel_dst = blockIdx.y;
13
- const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14
- const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
15
- const int64_t sample_dst = blockIdx.z;
16
- const int64_t sample_x = sample_dst / sample_ratio;
17
- const int64_t sample_y = sample_dst;
18
- const int tid = threadIdx.x;
8
+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
9
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
10
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11
+ const int row = blockIdx.x;
12
+ const int channel_dst = blockIdx.y;
13
+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14
+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
15
+ const int sample_dst = blockIdx.z;
16
+ const int sample_x = sample_dst / sample_ratio;
17
+ const int sample_y = sample_dst;
18
+ const int tid = threadIdx.x;
19
+
19
20
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
20
21
 
21
- x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
22
- y += sample_y *stride_sample_y + channel_y *stride_channel_y;
23
- dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
22
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
23
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
24
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
24
25
 
25
26
  const float2 * y2 = (const float2 *) y;
26
27
 
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
34
35
  __syncthreads();
35
36
  }
36
37
 
37
- float sumf = 0.0f;
38
+ float sumf[ncols_dst] = {0.0f};
38
39
 
39
40
  if constexpr (std::is_same<T, float>::value) {
40
41
  const float2 * x2 = (const float2 *) x;
41
42
 
42
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
43
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
43
44
  const float2 tmpx = x2[col2];
44
- const float2 tmpy = y2[col2];
45
- sumf += tmpx.x*tmpy.x;
46
- sumf += tmpx.y*tmpy.y;
45
+
46
+ #pragma unroll
47
+ for (int j = 0; j < ncols_dst; ++j) {
48
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
49
+ sumf[j] += tmpx.x*tmpy.x;
50
+ sumf[j] += tmpx.y*tmpy.y;
51
+ }
47
52
  }
48
53
  } else if constexpr (std::is_same<T, half>::value) {
49
54
  const half2 * x2 = (const half2 *) x;
50
55
 
51
56
  if (std::is_same<type_acc, float>::value) {
52
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
57
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
53
58
  const float2 tmpx = __half22float2(x2[col2]);
54
- const float2 tmpy = y2[col2];
55
- sumf += tmpx.x * tmpy.x;
56
- sumf += tmpx.y * tmpy.y;
59
+
60
+ #pragma unroll
61
+ for (int j = 0; j < ncols_dst; ++j) {
62
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
63
+ sumf[j] += tmpx.x * tmpy.x;
64
+ sumf[j] += tmpx.y * tmpy.y;
65
+ }
57
66
  }
58
67
  } else {
59
68
  #ifdef FP16_AVAILABLE
60
- half2 sumh2 = make_half2(0.0f, 0.0f);
69
+ half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
70
+
71
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
72
+ const half2 tmpx = x2[col2];
61
73
 
62
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
63
- const float2 tmp = y2[col2];
64
- sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
74
+ #pragma unroll
75
+ for (int j = 0; j < ncols_dst; ++j) {
76
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
77
+ sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
78
+ }
65
79
  }
66
80
 
67
- sumf = __low2float(sumh2) + __high2float(sumh2);
81
+ #pragma unroll
82
+ for (int j = 0; j < ncols_dst; ++j) {
83
+ sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
84
+ }
68
85
  #else
69
86
  NO_DEVICE_CODE;
70
87
  #endif // FP16_AVAILABLE
71
88
  }
72
89
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
73
90
  const int * x2 = (const int *) x;
74
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
75
- const int tmpx = x2[col2];
76
- const float2 tmpy = y2[col2];
77
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
78
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
91
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
92
+ const int tmpx = x2[col2];
93
+ #pragma unroll
94
+ for (int j = 0; j < ncols_dst; ++j) {
95
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
96
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
98
+ }
79
99
  }
80
100
  } else {
81
101
  static_assert(std::is_same<T, void>::value, "unsupported type");
82
102
  }
83
103
 
84
- sumf = warp_reduce_sum<warp_size>(sumf);
104
+ #pragma unroll
105
+ for (int j = 0; j < ncols_dst; ++j) {
106
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
85
107
 
86
- if (block_size > warp_size) {
87
- buf_iw[tid/warp_size] = sumf;
88
- __syncthreads();
89
- if (tid >= warp_size) {
90
- return;
108
+ if (block_size > warp_size) {
109
+ buf_iw[tid/warp_size] = sumf[j];
110
+ __syncthreads();
111
+ if (tid < warp_size) {
112
+ sumf[j] = buf_iw[tid];
113
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
114
+ }
115
+ if (j < ncols_dst) {
116
+ __syncthreads();
117
+ }
91
118
  }
92
- sumf = buf_iw[tid];
93
- sumf = warp_reduce_sum<warp_size>(sumf);
94
119
  }
95
120
 
96
- if (tid != 0) {
121
+ if (tid >= ncols_dst) {
97
122
  return;
98
123
  }
99
124
 
100
- dst[row] = sumf;
125
+ dst[tid*stride_col_dst + row] = sumf[tid];
101
126
  }
102
127
 
103
- template <typename T, typename type_acc>
128
+ template <typename T, typename type_acc, int ncols_dst>
104
129
  static void launch_mul_mat_vec_cuda(
105
130
  const T * x, const float * y, const int32_t * ids, float * dst,
106
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
131
+ const int64_t ncols, const int64_t nrows,
132
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
133
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
107
134
  const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
108
135
  const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
109
136
  cudaStream_t stream) {
110
- GGML_ASSERT(ncols % 2 == 0);
111
- GGML_ASSERT(stride_row % 2 == 0);
137
+ GGML_ASSERT(ncols % 2 == 0);
138
+ GGML_ASSERT(stride_row % 2 == 0);
139
+ GGML_ASSERT(stride_col_y % 2 == 0);
112
140
  GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
113
141
  GGML_ASSERT( nsamples_dst % nsamples_x == 0);
114
142
  const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
138
166
  const dim3 block_dims(block_size_best, 1, 1);
139
167
  switch (block_size_best) {
140
168
  case 32: {
141
- mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
142
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
143
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
169
+ mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
170
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
171
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
172
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
144
173
  } break;
145
174
  case 64: {
146
- mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
147
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
148
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
175
+ mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
176
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
177
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
178
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
149
179
  } break;
150
180
  case 96: {
151
- mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
152
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
153
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
181
+ mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
182
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
183
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
184
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
154
185
  } break;
155
186
  case 128: {
156
- mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
157
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
158
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
187
+ mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
188
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
189
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
190
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
159
191
  } break;
160
192
  case 160: {
161
- mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
162
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
163
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
193
+ mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
194
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
195
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
196
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
164
197
  } break;
165
198
  case 192: {
166
- mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
167
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
168
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
199
+ mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
200
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
201
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
202
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
169
203
  } break;
170
204
  case 224: {
171
- mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
172
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
173
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
205
+ mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
206
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
207
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
208
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
174
209
  } break;
175
210
  case 256: {
176
- mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
177
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
178
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
211
+ mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
212
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
213
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
214
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
179
215
  } break;
180
216
  default: {
181
217
  GGML_ABORT("fatal error");
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
183
219
  }
184
220
  }
185
221
 
222
+ template <typename T, typename type_acc>
223
+ static void mul_mat_vec_cuda_switch_ncols_dst(
224
+ const T * x, const float * y, const int32_t * ids, float * dst,
225
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
226
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
227
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
228
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
229
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
230
+ cudaStream_t stream) {
231
+ switch (ncols_dst) {
232
+ case 1:
233
+ launch_mul_mat_vec_cuda<T, type_acc, 1>
234
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
235
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
236
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
237
+ break;
238
+ case 2:
239
+ launch_mul_mat_vec_cuda<T, type_acc, 2>
240
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
241
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
242
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
243
+ break;
244
+ case 3:
245
+ launch_mul_mat_vec_cuda<T, type_acc, 3>
246
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
247
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
248
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
249
+ break;
250
+ case 4:
251
+ launch_mul_mat_vec_cuda<T, type_acc, 4>
252
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
253
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
254
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
255
+ break;
256
+ case 5:
257
+ launch_mul_mat_vec_cuda<T, type_acc, 5>
258
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
259
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
260
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
261
+ break;
262
+ case 6:
263
+ launch_mul_mat_vec_cuda<T, type_acc, 6>
264
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
265
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
266
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
267
+ break;
268
+ case 7:
269
+ launch_mul_mat_vec_cuda<T, type_acc, 7>
270
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
271
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
272
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
273
+ break;
274
+ case 8:
275
+ launch_mul_mat_vec_cuda<T, type_acc, 8>
276
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
277
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
278
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
279
+ break;
280
+ default:
281
+ GGML_ABORT("fatal error");
282
+ break;
283
+ }
284
+ }
285
+
186
286
  template<typename T>
187
287
  static void mul_mat_vec_cuda(
188
288
  const T * x, const float * y, const int32_t * ids, float * dst,
189
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
289
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
290
+ const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
291
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
190
292
  const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
191
293
  const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
192
294
  enum ggml_prec prec, cudaStream_t stream) {
193
295
  if constexpr(std::is_same<T, half>::value) {
194
296
  if (prec == GGML_PREC_DEFAULT) {
195
- launch_mul_mat_vec_cuda<T, half>
196
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
297
+ mul_mat_vec_cuda_switch_ncols_dst<T, half>
298
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
299
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
197
300
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
198
301
  return;
199
302
  }
200
303
  }
201
- launch_mul_mat_vec_cuda<T, float>
202
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
304
+ mul_mat_vec_cuda_switch_ncols_dst<T, float>
305
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
306
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
203
307
  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
204
308
  }
205
309
 
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
246
350
  const int64_t stride_channel_dst = ids ? s1 : s2;
247
351
  const int64_t stride_channel_y = ids ? s11 : s12;
248
352
 
249
- GGML_ASSERT(ncols_dst == 1);
353
+ GGML_ASSERT(!ids || ncols_dst == 1);
250
354
 
251
355
  switch (src0->type) {
252
356
  case GGML_TYPE_F32: {
253
357
  const float * src0_d = (const float *) src0->data;
254
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
358
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
255
359
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
256
360
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
257
361
  } break;
258
362
  case GGML_TYPE_F16: {
259
363
  const half * src0_d = (const half *) src0->data;
260
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
364
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
261
365
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
262
366
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
263
367
  } break;
264
368
  case GGML_TYPE_BF16: {
265
369
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
266
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
370
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
267
371
  ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
268
372
  ne03, ne3, s03, s13, s3, prec, ctx.stream());
269
373
  } break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
282
386
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
283
387
 
284
388
  const int64_t ne00 = src0->ne[0];
389
+ const int64_t ne10 = src1->ne[0];
390
+ const int64_t ne0 = dst->ne[0];
285
391
  const int64_t row_diff = row_high - row_low;
286
392
 
287
- GGML_ASSERT(src1_ncols == 1);
288
-
289
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
393
+ const int id = ggml_cuda_get_device();
394
+ const int cc = ggml_cuda_info().devices[id].cc;
290
395
  const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
291
396
 
292
397
 
293
398
  // ggml_cuda_op provides single, contiguous matrices
294
399
  const int64_t stride_row = ne00;
400
+ const int64_t stride_col_y = ne10;
401
+ const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
295
402
  const int64_t nchannels_x = 1;
296
403
  const int64_t nchannels_y = 1;
297
404
  const int64_t nchannels_dst = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
307
414
  switch (src0->type) {
308
415
  case GGML_TYPE_F32: {
309
416
  const float * src0_d = (const float *) src0_dd_i;
310
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
417
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
311
418
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
312
419
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
313
420
  } break;
314
421
  case GGML_TYPE_F16: {
315
422
  const half * src0_d = (const half *) src0_dd_i;
316
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
423
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
317
424
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
318
425
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
319
426
  } break;
320
427
  case GGML_TYPE_BF16: {
321
428
  const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
322
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
429
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
323
430
  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
324
431
  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
325
432
  } break;
@@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec(
334
441
  GGML_UNUSED(src1_ncols);
335
442
  GGML_UNUSED(src1_padded_row_size);
336
443
  }
444
+
445
+ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
446
+ if (src0_ne[0] % 2 != 0) {
447
+ return false;
448
+ }
449
+ switch (type) {
450
+ case GGML_TYPE_F32:
451
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
452
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
453
+ return ne11 <= 8;
454
+ }
455
+ if (cc >= GGML_CUDA_CC_TURING) {
456
+ return ne11 <= 4;
457
+ }
458
+ return ne11 <= 3;
459
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
460
+ if (fp32_mma_hardware_available(cc)) {
461
+ return ne11 <= 3;
462
+ }
463
+ return ne11 <= 8;
464
+ }
465
+ return ne11 <= 8;
466
+ case GGML_TYPE_F16:
467
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
468
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
469
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
470
+ return src0_small && ne11 <= 4;
471
+ }
472
+ if (fp16_mma_hardware_available(cc)) {
473
+ return src0_small && ne11 <= 3;
474
+ }
475
+ return ne11 <= 8;
476
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
477
+ if (fp16_mma_hardware_available(cc)) {
478
+ if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
479
+ return ne11 <= 5;
480
+ }
481
+ return ne11 <= 2;
482
+ }
483
+ return ne11 <= 8;
484
+ }
485
+ return ne11 <= 8;
486
+ case GGML_TYPE_BF16:
487
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
488
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
489
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
490
+ return src0_small && ne11 <= 4;
491
+ }
492
+ if (bf16_mma_hardware_available(cc)) {
493
+ return src0_small && ne11 <= 3;
494
+ }
495
+ return ne11 <= 8;
496
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
497
+ if (bf16_mma_hardware_available(cc)) {
498
+ return ne11 <= 3;
499
+ }
500
+ return ne11 <= 8;
501
+ }
502
+ return ne11 <= 8;
503
+ default:
504
+ return false;
505
+ }
506
+ }
@@ -1,8 +1,5 @@
1
1
  #include "common.cuh"
2
2
 
3
- // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
4
- #define MMV_MAX_ROWS 512
5
-
6
3
  void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
7
4
 
8
5
  void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
10
7
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
11
8
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
12
9
  const int64_t src1_padded_row_size, cudaStream_t stream);
10
+
11
+ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
10
10
  float * __restrict__ dst, const int64_t L) {
11
11
  GGML_UNUSED(src1_nb0);
12
12
  GGML_UNUSED(src2_nb0);
13
+
14
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
13
15
  const int bidx = blockIdx.x; // split along B
14
16
  const int bidy = blockIdx.y; // split along D
15
17
  const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
44
46
  if (N == 16) {
45
47
  #pragma unroll
46
48
  for (size_t i = 0; i < splitD / 4; i += 2) {
47
- float value = A_block[(wid * warpSize + i) * stride_A + wtid];
49
+ float value = A_block[(wid * warp_size + i) * stride_A + wtid];
48
50
  // todo: bank conflict
49
51
  // I am always confused with how to use the swizzling method to solve
50
52
  // bank conflit. Hoping somebody can tell me.
51
- smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
53
+ smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
52
54
  }
53
55
  #pragma unroll
54
56
  for (size_t i = 0; i < splitD / 4; i += 2) {
55
- float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56
- smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57
+ float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
58
+ smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57
59
  }
58
60
  }
59
61
 
@@ -1,25 +1,9 @@
1
1
  #include "sumrows.cuh"
2
2
 
3
- static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
4
- const int row = blockIdx.x;
5
- const int col = threadIdx.x;
6
-
7
- float sum = 0.0f;
8
- for (int i = col; i < ncols; i += blockDim.x) {
9
- sum += x[row * ncols + i];
10
- }
11
-
12
- sum = warp_reduce_sum(sum);
13
-
14
- if (col == 0) {
15
- dst[row] = sum;
16
- }
17
- }
18
-
19
3
  void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
20
4
  const dim3 block_dims(WARP_SIZE, 1, 1);
21
5
  const dim3 block_nums(nrows, 1, 1);
22
- k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
6
+ reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
23
7
  }
24
8
 
25
9
  void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
35
19
  const int64_t ncols = src0->ne[0];
36
20
  const int64_t nrows = ggml_nrows(src0);
37
21
 
38
- sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
22
+ const dim3 block_dims(WARP_SIZE, 1, 1);
23
+ const dim3 block_nums(nrows, 1, 1);
24
+
25
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
39
26
  }
@@ -1,5 +1,4 @@
1
1
  #include "common.cuh"
2
2
 
3
3
  void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
4
-
5
4
  void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);