whispercpp 1.3.2 → 1.3.3
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-cuda/common.cuh

@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"

-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>

@@ -76,11 +76,9 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)

 // Moore Threads
-#define
-
-#define
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD

 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -203,13 +201,13 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

-#if !
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
 #define FP16_MMA_AVAILABLE
-#endif // !
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)

-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
@@ -219,9 +217,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) &&
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) &&
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)

 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +231,8 @@ static bool fast_fp16_available(const int cc) {

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc)
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 }

 // Any FP16 tensor core instructions are available for ggml code.
@@ -241,15 +240,35 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
+        GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+}
+
+static bool bf16_mma_hardware_available(const int cc) {
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+}
+
+static bool fp32_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_CDNA(cc);
 }

 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
@@ -262,11 +281,11 @@ static bool cp_async_available(const int cc) {
 }

 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-    return
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+    return 64;
 #else
     return 32;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
 }

 [[noreturn]]
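Note on the hunk above: GFX8/GFX9 GPUs (GCN, Vega, CDNA) execute 64-lane wavefronts, while RDNA parts default to wave32 like NVIDIA hardware, so the HIP branch now returns 64 only for those targets. A minimal illustration of what the constexpr qualifier buys callers (the assert is a sketch, not part of the diff):

    // ggml_cuda_get_physical_warp_size() is constexpr, so its result can feed
    // compile-time constructs such as template arguments or static_assert.
    constexpr int ws = ggml_cuda_get_physical_warp_size();
    static_assert(ws == 32 || ws == 64, "32 on NVIDIA/RDNA, 64 on GFX8/GFX9");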
@@ -362,6 +381,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }

+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
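The new reduce_rows_f32 template reduces one row per block: each thread accumulates a strided slice of the row, warp_reduce_sum combines the partials, and lane 0 writes either the sum or the mean. Since only a warp-level reduction is used, a launch with one warp per block is implied; a hedged usage sketch (buffer names are hypothetical):

    // One block per row of a row-major [nrows, ncols] f32 matrix; block size
    // WARP_SIZE so that warp_reduce_sum covers all participating threads.
    const dim3 grid(nrows, 1, 1);
    const dim3 block(WARP_SIZE, 1, 1);
    reduce_rows_f32</*norm=*/true><<<grid, block, 0, stream>>>(d_x, d_mean, ncols);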
@@ -466,9 +505,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }

-// TODO: move to ggml-common.h
-static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

 static __device__ __forceinline__ float get_alibi_slope(
@@ -635,6 +671,7 @@ struct ggml_cuda_device_info {
     int nsm; // number of streaming multiprocessors
     size_t smpb; // max. shared memory per block
     size_t smpbo; // max. shared memory per block (with opt-in)
+    bool integrated; // Device is integrated as opposed to discrete
     bool vmm; // virtual memory support
     size_t vmm_granularity; // granularity of virtual memory
     size_t total_vram;
@@ -769,21 +806,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }

-    ~ggml_backend_cuda_context()
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();

     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
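The inline destructor above becomes an out-of-line declaration; its definition presumably moves into ggml-cuda.cu (also modified in this diff), so the header no longer instantiates the teardown code at every include site. Based on the body removed here, the out-of-line definition would look like:

    ggml_backend_cuda_context::~ggml_backend_cuda_context() {
        if (copy_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(copy_event));
        }
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                if (streams[i][j] != nullptr) {
                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                }
            }
            if (cublas_handles[i] != nullptr) {
                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
            }
        }
    }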
data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu

@@ -0,0 +1,161 @@
+#include "conv2d-dw.cuh"
+
+struct conv_params {
+    int in_w, in_h;
+    int out_w, out_h;
+    int kernel_w, kernel_h;
+    int stride_x, stride_y;
+    int padding_x, padding_y;
+    int dilation_x, dilation_y;
+    int channels, batches;
+};
+
+struct kernel_bounds {
+    int y_min, y_max;
+    int x_min, x_max;
+};
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
+    kernel_bounds bounds;
+    bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
+    bounds.y_max =
+        min(params.kernel_h,
+            (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
+    bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
+    bounds.x_max =
+        min(params.kernel_w,
+            (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
+    return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
+    return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
+        return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
+    }
+
+    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
+        return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
+    }
+
+    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
+        return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
+               y * params.out_w + x;
+    }
+
+    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
+                                          int & out_x) {
+        out_x = global_idx % params.out_w;
+        out_y = (global_idx / params.out_w) % params.out_h;
+        c = (global_idx / (params.out_w * params.out_h)) % params.channels;
+        n = global_idx / (params.out_w * params.out_h * params.channels);
+    }
+};
+
+struct cwhn_layout {
+    __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
+        return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
+    }
+
+    __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
+        return (ky * params.kernel_w + kx) * params.channels + c;
+    }
+
+    __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
+        return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
+               x * params.channels + c;
+    }
+
+    __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
+                                          int & out_x) {
+        c = global_idx % params.channels;
+        out_x = (global_idx / params.channels) % params.out_w;
+        out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
+        n = global_idx / (params.channels * params.out_w * params.out_h);
+    }
+};
+
+template <typename T, typename Layout>
+__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
+                                 const int in_w, const int in_h, const int out_w, const int out_h,
+                                 const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
+                                 const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
+                                 const int channels, const int batches) {
+    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int total_elements = batches * channels * out_h * out_w;
+
+    if (global_idx >= total_elements) {
+        return;
+    }
+
+    conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
+                           stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
+
+    int batch_idx, channel_idx, out_y_idx, out_x_idx;
+    Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
+
+    T accumulator = 0;
+    kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
+
+    for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
+        int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
+
+        for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
+            int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
+
+            const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
+            const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
+
+            accumulator += input_val * kernel_val;
+        }
+    }
+
+    output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
+}
+
+void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input = dst->src[1];
+
+    GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    const float * w_d = (const float *) kernel->data;
+    const float * x_d = (const float *) input->data;
+    float * y_d = (float *) dst->data;
+
+    const int32_t * p = (const int32_t *) dst->op_params;
+    const int stride_x = p[0];
+    const int stride_y = p[1];
+    const int padding_x = p[2];
+    const int padding_y = p[3];
+    const int dilation_x = p[4];
+    const int dilation_y = p[5];
+
+    const int in_w = input->ne[0];
+    const int in_h = input->ne[1];
+    const int kernel_w = kernel->ne[0];
+    const int kernel_h = kernel->ne[1];
+    const int out_w = dst->ne[0];
+    const int out_h = dst->ne[1];
+    const int channels = dst->ne[2];
+    const int batches = dst->ne[3];
+
+    cudaStream_t st = ctx.stream();
+
+    const int total = batches * channels * out_h * out_w;
+    const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
+
+    if (ggml_is_contiguous(input)) {
+        conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
+            dilation_x, dilation_y, channels, batches);
+    } else if (ggml_is_contiguous_channels(input)) {
+        conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+            x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
+            dilation_x, dilation_y, channels, batches);
+    } else {
+        GGML_ABORT("Unsupported memory layout for conv_2d_dw");
+    }
+}
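The depthwise kernel maps one thread to one output element and clamps the kernel window with calculate_kernel_bounds, so the inner loops never need a padding branch. The output extents implied by calculate_input_coord follow the standard convolution arithmetic; a small host-side helper to illustrate (not part of the diff):

    // Output size along one axis for the stride/padding/dilation used above.
    static int conv2d_out_size(int in, int kernel, int stride, int padding, int dilation) {
        return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
    }
    // e.g. in=32, kernel=3, stride=1, padding=1, dilation=1 -> 32 (same-size output)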
data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu

@@ -0,0 +1,91 @@
+#include <algorithm>
+
+#include "conv2d-transpose.cuh"
+#include "ggml.h"
+
+__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
+                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
+                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
+                                        const int c_in, const int c_out, const int batches) {
+    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    const int total_elements = out_w * out_h * c_out * batches;
+
+    if (global_idx >= total_elements) {
+        return;
+    }
+
+    const int out_x_idx = global_idx % out_w;
+    const int out_y_idx = (global_idx / out_w) % out_h;
+    const int c_idx = (global_idx / (out_w * out_h)) % c_out;
+    const int n_idx = global_idx / (out_w * out_h * c_out);
+
+    float accumulator = 0;
+    // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
+
+    for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
+        for (int kh = 0; kh < kernel_h; ++kh) {
+            int in_y = out_y_idx - kh;
+            if (in_y < 0 || in_y % stride) continue;
+            in_y /= stride;
+            if (in_y >= in_h) continue;
+
+            for (int kw = 0; kw < kernel_w; ++kw) {
+                int in_x = out_x_idx - kw;
+                if (in_x < 0 || in_x % stride) continue;
+                in_x /= stride;
+                if (in_x >= in_w) continue;
+
+                const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
+                const int kernel_idx =
+                    (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
+
+                float input_val = input[input_idx];
+                half kern_val = kernel[kernel_idx];
+
+                accumulator += input_val * (float) kern_val;
+            }
+        }
+    }
+
+    output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
+}
+
+//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
+void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input = dst->src[1];
+
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+
+    const float * input_data = (const float *) input->data;
+    float * output_data = (float *) dst->data;
+    const half * kernel_data = (const half *) kernel->data;
+
+    const int input_w = input->ne[0];
+    const int input_h = input->ne[1];
+    const int output_w = dst->ne[0];
+    const int output_h = dst->ne[1];
+    const int channels_in = input->ne[2];
+    const int channels_out = kernel->ne[2];
+    const int kernel_w = kernel->ne[0];
+    const int kernel_h = kernel->ne[1];
+    const int stride = dst->op_params[0];
+    const int batches = input->ne[3];
+
+    GGML_ASSERT(channels_in == kernel->ne[3]);
+    GGML_ASSERT(stride > 0);
+
+    cudaStream_t st = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(input));
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int total = (output_w * output_h * channels_out * batches);
+    const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
+
+    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
+        channels_in, channels_out, batches);
+}
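The transposed-convolution kernel gathers rather than scatters: each output pixel walks the kernel window and accepts only input positions where out - k is stride-aligned and in bounds, which avoids atomics entirely. With only a stride parameter (no padding or dilation here), the expected output extent is (in - 1) * stride + kernel; an illustrative helper (not part of the diff):

    static int conv2d_transpose_out_size(int in, int kernel, int stride) {
        return (in - 1) * stride + kernel;
    }
    // e.g. in=16, kernel=4, stride=2 -> 34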
data/ext/sources/ggml/src/ggml-cuda/convert.cu

@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return nullptr;
     }
 }
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half, float>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16, float>;
+        default:
+            return nullptr;
+    }
+}
data/ext/sources/ggml/src/ggml-cuda/convert.cuh

@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);

+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
 typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh

@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
-
-    ((float *) meta)[
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
     }

     __syncthreads();
data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     float KQ_max_scale[cols_per_thread];
 #pragma unroll
     for (int col = 0; col < cols_per_thread; ++col) {
-
+        const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+        KQ_max_scale[col] = expf(KQ_max_diff);
         KQ_max[col] = KQ_max_new[col];

+        *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
         // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
         KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
     }
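The rewritten scale computation makes the flush-to-zero explicit: expf(KQ_max - KQ_max_new) is at most 1, and multiplying its bit pattern by a 0/1 predicate either clears every bit (yielding exactly 0.0f) or leaves the value untouched, with no branch. An isolated illustration of the trick (variable names are hypothetical):

    float scale = expf(diff);                                 // diff <= 0, so scale in (0, 1]
    *((uint32_t *) &scale) *= diff >= SOFTMAX_FTZ_THRESHOLD;  // branchless flush to exact zero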
@@ -1246,7 +1249,7 @@ static __global__ void flash_attn_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
-#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING

     static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");

data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -9,7 +9,11 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+#ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>