whispercpp 1.3.2 → 1.3.3
This diff shows the content of publicly released package versions as published to the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
--- /dev/null
+++ b/data/ext/sources/ggml/src/ggml-cuda/mean.cu
@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}
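Both the new `mean.cu` above and the trimmed-down `sumrows.cu` further below delegate to a shared `reduce_rows_f32<norm>` kernel from `common.cuh`: one block of `WARP_SIZE` threads per row, a strided pass over the columns, then a warp reduction, with `norm` deciding whether the row sum is divided by the column count. The removed `k_sum_rows_f32` kernel (visible in the `sumrows.cu` hunk) already had this shape. A minimal sketch of the pattern, for orientation only and not the actual `reduce_rows_f32` implementation:

```cpp
// Illustrative sketch of a templated row-reduction kernel of this shape.
// norm == false -> row sums (sum_rows), norm == true -> row means (mean).
// warp_reduce_sum is the helper provided by common.cuh.
template <bool norm>
static __global__ void reduce_rows_f32_sketch(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;   // one block per row
    const int col = threadIdx.x;  // WARP_SIZE threads stride over the columns

    float sum = 0.0f;
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x[row*ncols + i];
    }

    sum = warp_reduce_sum(sum);

    if (col == 0) {
        dst[row] = norm ? sum / ncols : sum;
    }
}
```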
--- a/data/ext/sources/ggml/src/ggml-cuda/mmv.cu
+++ b/data/ext/sources/ggml/src/ggml-cuda/mmv.cu
@@ -2,25 +2,26 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename T, typename type_acc, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const
-        const
-        const
-        const
-        const
-        const
-        const
-        const
-        const
-        const
-        const int
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
+    const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst = blockIdx.z;
+    const int sample_x = sample_dst / sample_ratio;
+    const int sample_y = sample_dst;
+    const int tid = threadIdx.x;
+
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
-    y += sample_y *stride_sample_y + channel_y *stride_channel_y;
-    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
+    x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+    y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf = 0.0f;
+    float sumf[ncols_dst] = {0.0f};
 
     if constexpr (std::is_same<T, float>::value) {
         const float2 * x2 = (const float2 *) x;
 
-        for (
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const float2 tmpx = x2[col2];
-
-
-
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
         }
     } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            for (
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
-
-
-
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
             }
         } else {
 #ifdef FP16_AVAILABLE
-            half2 sumh2 =
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
 
-
-
-
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
             }
 
-
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
 #else
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        for (
-            const int
-
-
-
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
         }
     } else {
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
 
-
-
-
-
-
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
         }
-        sumf = buf_iw[tid];
-        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
-    if (tid
+    if (tid >= ncols_dst) {
         return;
     }
 
-    dst[row] = sumf;
+    dst[tid*stride_col_dst + row] = sumf[tid];
 }
 
-template <typename T, typename type_acc>
+template <typename T, typename type_acc, int ncols_dst>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
-    GGML_ASSERT(ncols
-    GGML_ASSERT(stride_row
+    GGML_ASSERT(ncols % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT( nsamples_dst % nsamples_x == 0);
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case 32: {
-            mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 64: {
-            mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 96: {
-            mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
-            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
-            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
-            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
-            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
-            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row,
-
+            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
 template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     if constexpr(std::is_same<T, half>::value) {
         if (prec == GGML_PREC_DEFAULT) {
-
-            (x, y, ids, dst, ncols, nrows,
+            mul_mat_vec_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
-
-    (x, y, ids, dst, ncols, nrows,
+    mul_mat_vec_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
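The new `mul_mat_vec_cuda_switch_ncols_dst` helper exists because `ncols_dst` (the batch size `ne11`, i.e. the number of destination columns) is only known at run time, while the kernel wants it as a compile-time constant so that `float sumf[ncols_dst]` can live in registers and the `#pragma unroll` loops can actually unroll. The switch lowers the runtime value (1–8) to a template argument. A stripped-down, self-contained sketch of that dispatch pattern, with hypothetical names rather than the functions above:

```cpp
// Hypothetical illustration of runtime -> template-parameter dispatch.
#include <cstdio>
#include <cstdlib>

template <int ncols_dst>
static void launch_kernel() {
    // In mmv.cu this is the templated mul_mat_vec launch; here we only print.
    std::printf("instantiated for ncols_dst = %d\n", ncols_dst);
}

static void switch_ncols_dst(int ncols_dst) {
    switch (ncols_dst) {
        case 1: launch_kernel<1>(); break;
        case 2: launch_kernel<2>(); break;
        case 3: launch_kernel<3>(); break;
        case 4: launch_kernel<4>(); break;
        case 5: launch_kernel<5>(); break;
        case 6: launch_kernel<6>(); break;
        case 7: launch_kernel<7>(); break;
        case 8: launch_kernel<8>(); break;
        default: std::abort(); // larger batches are routed to other kernels
    }
}

int main() {
    switch_ncols_dst(4);
    return 0;
}
```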
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t stride_channel_dst = ids ? s1 : s2;
     const int64_t stride_channel_y = ids ? s11 : s12;
 
-    GGML_ASSERT(ncols_dst == 1);
+    GGML_ASSERT(!ids || ncols_dst == 1);
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0 = dst->ne[0];
     const int64_t row_diff = row_high - row_low;
 
-
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
 
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row = ne00;
+    const int64_t stride_col_y = ne10;
+    const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
     const int64_t nchannels_x = 1;
     const int64_t nchannels_y = 1;
     const int64_t nchannels_dst = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
@@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return ne11 <= 8;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
--- a/data/ext/sources/ggml/src/ggml-cuda/mmv.cuh
+++ b/data/ext/sources/ggml/src/ggml-cuda/mmv.cuh
@@ -1,8 +1,5 @@
 #include "common.cuh"
 
-// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
-#define MMV_MAX_ROWS 512
-
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
--- a/data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
     float * __restrict__ dst, const int64_t L) {
     GGML_UNUSED(src1_nb0);
     GGML_UNUSED(src2_nb0);
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     const int bidx = blockIdx.x; // split along B
     const int bidy = blockIdx.y; // split along D
     const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
     if (N == 16) {
 #pragma unroll
         for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid *
+            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
             // todo: bank conflict
             // I am always confused with how to use the swizzling method to solve
             // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid *
+            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
         }
 #pragma unroll
         for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid *
-            smem_s0[(wid *
+            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
+            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
         }
     }
 
--- a/data/ext/sources/ggml/src/ggml-cuda/sumrows.cu
+++ b/data/ext/sources/ggml/src/ggml-cuda/sumrows.cu
@@ -1,25 +1,9 @@
 #include "sumrows.cuh"
 
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }