whispercpp 1.3.5 → 1.3.7
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -61,11 +61,24 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|
|
61
61
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
|
|
62
62
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
|
|
63
63
|
|
|
64
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true);
|
|
65
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true);
|
|
66
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
|
|
67
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);
|
|
68
|
+
|
|
64
69
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
|
|
65
70
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
|
|
66
71
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
|
67
72
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
|
68
73
|
|
|
74
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
|
75
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
|
76
|
+
|
|
77
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
|
|
78
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
|
|
79
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
|
80
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
|
81
|
+
|
|
69
82
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
|
70
83
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
|
71
84
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
|
@@ -80,6 +93,14 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|
|
80
93
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
|
81
94
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
|
82
95
|
|
|
96
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
|
97
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
|
98
|
+
|
|
99
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
|
100
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
|
101
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
|
102
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
|
103
|
+
|
|
83
104
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
|
84
105
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
|
85
106
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
|
@@ -89,6 +110,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|
|
89
110
|
}
|
|
90
111
|
|
|
91
112
|
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
|
113
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
|
|
114
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
|
|
115
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
|
|
116
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
|
|
117
|
+
|
|
92
118
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
|
93
119
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
|
94
120
|
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
|
@@ -98,6 +124,110 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|
|
98
124
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
99
125
|
}
|
|
100
126
|
|
|
127
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
|
|
128
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true);
|
|
129
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
|
130
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
|
131
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true);
|
|
132
|
+
|
|
133
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true);
|
|
134
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true);
|
|
135
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
|
136
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true);
|
|
137
|
+
|
|
138
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true);
|
|
139
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true);
|
|
140
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
|
141
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true);
|
|
142
|
+
|
|
143
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true);
|
|
144
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true);
|
|
145
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
|
146
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true);
|
|
147
|
+
|
|
148
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true);
|
|
149
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true);
|
|
150
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
|
151
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true);
|
|
152
|
+
|
|
153
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true);
|
|
154
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true);
|
|
155
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true);
|
|
156
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true);
|
|
157
|
+
|
|
158
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true);
|
|
159
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true);
|
|
160
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true);
|
|
161
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true);
|
|
162
|
+
|
|
163
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true);
|
|
164
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true);
|
|
165
|
+
|
|
166
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
|
|
167
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
|
|
168
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
|
169
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true);
|
|
170
|
+
|
|
171
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
|
|
172
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
|
|
173
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true);
|
|
174
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true);
|
|
175
|
+
|
|
176
|
+
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
|
180
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true);
|
|
181
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true);
|
|
182
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true);
|
|
183
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true);
|
|
184
|
+
|
|
185
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true);
|
|
186
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true);
|
|
187
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true);
|
|
188
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
|
189
|
+
|
|
190
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true);
|
|
191
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true);
|
|
192
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true);
|
|
193
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
|
194
|
+
|
|
195
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true);
|
|
196
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true);
|
|
197
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true);
|
|
198
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
|
199
|
+
|
|
200
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true);
|
|
201
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true);
|
|
202
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true);
|
|
203
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
|
204
|
+
|
|
205
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true);
|
|
206
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true);
|
|
207
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true);
|
|
208
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true);
|
|
209
|
+
|
|
210
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
|
211
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
|
212
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true);
|
|
213
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true);
|
|
214
|
+
|
|
215
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true);
|
|
216
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true);
|
|
217
|
+
|
|
218
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
|
219
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
|
220
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true);
|
|
221
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true);
|
|
222
|
+
|
|
223
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
|
|
224
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
|
|
225
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true);
|
|
226
|
+
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true);
|
|
227
|
+
|
|
228
|
+
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
229
|
+
}
|
|
230
|
+
|
|
101
231
|
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
|
102
232
|
if (ampere_mma_available(cc)) {
|
|
103
233
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
@@ -105,6 +235,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
|
|
|
105
235
|
if (turing_mma_available(cc)) {
|
|
106
236
|
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
107
237
|
}
|
|
238
|
+
if (amd_mfma_available(cc)) {
|
|
239
|
+
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
|
240
|
+
}
|
|
241
|
+
if (amd_wmma_available(cc)) {
|
|
242
|
+
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
|
243
|
+
}
|
|
108
244
|
GGML_ASSERT(volta_mma_available(cc));
|
|
109
245
|
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
110
246
|
}
|
|
@@ -114,8 +250,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
|
|
|
114
250
|
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
|
115
251
|
#elif defined(TURING_MMA_AVAILABLE)
|
|
116
252
|
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
|
253
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
254
|
+
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
|
117
255
|
#elif defined(VOLTA_MMA_AVAILABLE)
|
|
118
256
|
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
|
257
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
258
|
+
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
|
119
259
|
#else
|
|
120
260
|
GGML_UNUSED_VARS(DKQ, DV, ncols);
|
|
121
261
|
return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
|
|
@@ -186,6 +326,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
|
|
|
186
326
|
return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
|
|
187
327
|
}
|
|
188
328
|
|
|
329
|
+
static constexpr __device__ int get_cols_per_thread() {
|
|
330
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
331
|
+
return 1; // AMD has a single column per thread.
|
|
332
|
+
#else
|
|
333
|
+
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
|
334
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
static __host__ int get_cols_per_warp(const int cc) {
|
|
338
|
+
if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
|
|
339
|
+
return 16;
|
|
340
|
+
} else {
|
|
341
|
+
// Volta
|
|
342
|
+
return 32;
|
|
343
|
+
}
|
|
344
|
+
}
|
|
345
|
+
|
|
189
346
|
// ------------------------------------------------------------------------------------------------------------------
|
|
190
347
|
|
|
191
348
|
static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
|
|
@@ -206,21 +363,23 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
|
|
|
206
363
|
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
|
207
364
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
208
365
|
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
|
366
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
209
367
|
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
|
210
|
-
// The minimum granularity
|
|
368
|
+
// The minimum granularity is 16 bytes.
|
|
369
|
+
constexpr int h2_per_chunk = 16/sizeof(half2);
|
|
370
|
+
const int chunks_per_row = D2 / h2_per_chunk;
|
|
211
371
|
if constexpr (use_cp_async) {
|
|
372
|
+
static_assert(warp_size == 32, "bad warp_size");
|
|
212
373
|
static_assert(!oob_check, "OOB check not compatible with cp_async");
|
|
213
374
|
constexpr int preload = 64;
|
|
214
|
-
constexpr int h2_per_chunk = 16/sizeof(half2);
|
|
215
|
-
const int chunks_per_row = D2 / h2_per_chunk;
|
|
216
375
|
|
|
217
376
|
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
|
218
377
|
|
|
219
378
|
auto load = [&] __device__ (auto n) {
|
|
220
|
-
const int stride_k =
|
|
221
|
-
const int k0_start = stride_k ==
|
|
379
|
+
const int stride_k = warp_size >> n;
|
|
380
|
+
const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
|
222
381
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
223
|
-
const int stride_i =
|
|
382
|
+
const int stride_i = warp_size / stride_k;
|
|
224
383
|
|
|
225
384
|
if (k0_start == k0_stop) {
|
|
226
385
|
return;
|
|
@@ -228,7 +387,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
228
387
|
|
|
229
388
|
#pragma unroll
|
|
230
389
|
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
|
231
|
-
const int i = i0 + threadIdx.y*stride_i + (stride_k ==
|
|
390
|
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
232
391
|
|
|
233
392
|
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
|
234
393
|
break;
|
|
@@ -236,7 +395,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
236
395
|
|
|
237
396
|
#pragma unroll
|
|
238
397
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
239
|
-
const int k = k0 + (stride_k ==
|
|
398
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
240
399
|
|
|
241
400
|
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
|
242
401
|
}
|
|
@@ -250,12 +409,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
250
409
|
// 6: max 1*16= 16 bytes, 8 half
|
|
251
410
|
ggml_cuda_unroll<6>{}(load);
|
|
252
411
|
} else {
|
|
253
|
-
|
|
412
|
+
const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
|
|
254
413
|
auto load = [&] __device__ (const int n) {
|
|
255
|
-
const int stride_k =
|
|
256
|
-
const int k0_start = stride_k ==
|
|
257
|
-
const int k0_stop =
|
|
258
|
-
const int stride_i =
|
|
414
|
+
const int stride_k = 32 >> n;
|
|
415
|
+
const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
|
416
|
+
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
417
|
+
const int stride_i = warp_size / stride_k;
|
|
259
418
|
|
|
260
419
|
if (k0_start == k0_stop) {
|
|
261
420
|
return;
|
|
@@ -263,7 +422,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
263
422
|
|
|
264
423
|
#pragma unroll
|
|
265
424
|
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
|
266
|
-
const int i = i0 + threadIdx.y*stride_i + (stride_k ==
|
|
425
|
+
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
267
426
|
|
|
268
427
|
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
|
269
428
|
break;
|
|
@@ -271,17 +430,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
271
430
|
|
|
272
431
|
#pragma unroll
|
|
273
432
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
274
|
-
const int k = k0 + (stride_k ==
|
|
433
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
275
434
|
|
|
276
|
-
tile_KV
|
|
435
|
+
ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
|
|
436
|
+
!oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
|
|
277
437
|
}
|
|
278
438
|
}
|
|
279
439
|
};
|
|
280
|
-
// 1: max 32*
|
|
281
|
-
// 2: max 16*
|
|
282
|
-
// 3: max 8*
|
|
283
|
-
// 4: max 4*
|
|
284
|
-
|
|
440
|
+
// 1: max 32*16=512 bytes, 256 half
|
|
441
|
+
// 2: max 16*16=256 bytes, 128 half
|
|
442
|
+
// 3: max 8*16=128 bytes, 64 half
|
|
443
|
+
// 4: max 4*16= 64 bytes, 32 half
|
|
444
|
+
// 5: max 2*16= 32 bytes, 16 half
|
|
445
|
+
// 6: max 1*16= 16 bytes, 8 half
|
|
446
|
+
ggml_cuda_unroll<6>{}(load);
|
|
285
447
|
}
|
|
286
448
|
}
|
|
287
449
|
|
|
@@ -289,18 +451,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
|
|
|
289
451
|
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
290
452
|
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
|
291
453
|
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
|
454
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
292
455
|
if constexpr (use_cp_async) {
|
|
293
|
-
static_assert(nbatch_fa <= 8*
|
|
456
|
+
static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
|
294
457
|
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
|
295
458
|
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
|
296
|
-
constexpr int cols_per_warp = 8*
|
|
459
|
+
constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
|
|
297
460
|
constexpr int stride_j = nwarps * cols_per_warp;
|
|
298
461
|
|
|
299
462
|
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
|
300
463
|
|
|
301
464
|
#pragma unroll
|
|
302
465
|
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
303
|
-
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (
|
|
466
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
|
304
467
|
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
305
468
|
|
|
306
469
|
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
@@ -309,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
309
472
|
|
|
310
473
|
const int i = 8 * (threadIdx.x % (nbatch_fa/8));
|
|
311
474
|
|
|
312
|
-
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
|
|
475
|
+
cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i);
|
|
313
476
|
}
|
|
314
477
|
} else if constexpr (oob_check) {
|
|
315
478
|
#pragma unroll
|
|
@@ -322,27 +485,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
322
485
|
}
|
|
323
486
|
|
|
324
487
|
#pragma unroll
|
|
325
|
-
for (int i0 = 0; i0 < nbatch_fa; i0 +=
|
|
488
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
|
326
489
|
const int i = i0 + threadIdx.x;
|
|
327
490
|
|
|
328
|
-
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
|
491
|
+
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f);
|
|
329
492
|
}
|
|
330
493
|
}
|
|
331
|
-
} else if constexpr (nbatch_fa < 2*
|
|
332
|
-
constexpr int cols_per_warp = 2*
|
|
494
|
+
} else if constexpr (nbatch_fa < 2*warp_size) {
|
|
495
|
+
constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
|
|
333
496
|
constexpr int stride_j = nwarps * cols_per_warp;
|
|
334
497
|
#pragma unroll
|
|
335
498
|
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
|
336
|
-
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (
|
|
499
|
+
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
|
337
500
|
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
|
338
501
|
|
|
339
502
|
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
|
340
503
|
break;
|
|
341
504
|
}
|
|
342
505
|
|
|
343
|
-
const int i = threadIdx.x % (
|
|
506
|
+
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
|
344
507
|
|
|
345
|
-
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
|
508
|
+
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i);
|
|
346
509
|
}
|
|
347
510
|
} else {
|
|
348
511
|
#pragma unroll
|
|
@@ -355,17 +518,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
355
518
|
}
|
|
356
519
|
|
|
357
520
|
#pragma unroll
|
|
358
|
-
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*
|
|
521
|
+
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
|
359
522
|
const int i = i0 + 2*threadIdx.x;
|
|
360
523
|
|
|
361
|
-
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
|
524
|
+
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i);
|
|
362
525
|
}
|
|
363
526
|
}
|
|
364
527
|
}
|
|
365
528
|
}
|
|
366
529
|
|
|
367
530
|
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
|
368
|
-
bool use_logit_softcap, bool
|
|
531
|
+
bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
|
369
532
|
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
|
370
533
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
371
534
|
const float2 * const __restrict__ Q_f2,
|
|
@@ -393,33 +556,34 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
393
556
|
const int jt,
|
|
394
557
|
const int kb0,
|
|
395
558
|
const int k_VKQ_sup) {
|
|
396
|
-
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
559
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
560
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
397
561
|
constexpr int ncols = ncols1 * ncols2;
|
|
398
562
|
constexpr int cols_per_warp = T_B_KQ::I;
|
|
399
|
-
constexpr int cols_per_thread =
|
|
400
|
-
constexpr int np = nwarps *
|
|
563
|
+
constexpr int cols_per_thread = get_cols_per_thread();
|
|
564
|
+
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
|
401
565
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
402
566
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
|
403
567
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
|
404
568
|
constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
|
|
405
569
|
constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
|
|
406
570
|
|
|
407
|
-
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
408
571
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
409
572
|
|
|
410
|
-
|
|
411
|
-
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
573
|
+
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
|
412
574
|
|
|
413
575
|
const int k_VKQ_0 = kb0 * nbatch_fa;
|
|
414
576
|
#if defined(TURING_MMA_AVAILABLE)
|
|
415
577
|
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
|
578
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
579
|
+
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
|
416
580
|
#else // Volta
|
|
417
581
|
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
|
418
582
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
419
583
|
|
|
420
584
|
if constexpr (nstages > 1) {
|
|
421
585
|
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
|
422
|
-
static_assert(!
|
|
586
|
+
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
|
423
587
|
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
|
424
588
|
constexpr bool use_cp_async = true;
|
|
425
589
|
cp_async_wait_all();
|
|
@@ -434,12 +598,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
434
598
|
}
|
|
435
599
|
}
|
|
436
600
|
|
|
601
|
+
// For MLA K and V have the same data.
|
|
602
|
+
// Therefore, iterate over K in reverse and later re-use the data if possible.
|
|
437
603
|
#pragma unroll
|
|
438
|
-
for (int k0_start =
|
|
604
|
+
for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
|
|
439
605
|
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
|
440
|
-
const int k0_diff = k0_stop - k0_start;
|
|
441
606
|
|
|
442
607
|
if constexpr (nstages <= 1) {
|
|
608
|
+
const int k0_diff = k0_stop - k0_start;
|
|
443
609
|
constexpr bool use_cp_async = nstages == 1;
|
|
444
610
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
445
611
|
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
|
|
@@ -461,13 +627,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
461
627
|
if constexpr (cols_per_warp == 8) {
|
|
462
628
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
|
463
629
|
} else {
|
|
464
|
-
// Wide version of KQ_C is column-major
|
|
630
|
+
// Wide version of KQ_C is column-major
|
|
631
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
632
|
+
// AMD matrix C is column-major.
|
|
633
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
|
634
|
+
#else
|
|
635
|
+
// swap A and B for CUDA.
|
|
465
636
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
|
637
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
466
638
|
}
|
|
467
639
|
}
|
|
468
640
|
}
|
|
469
641
|
} else {
|
|
470
|
-
|
|
642
|
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
471
643
|
#pragma unroll
|
|
472
644
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
|
473
645
|
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
|
@@ -479,8 +651,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
479
651
|
T_A_KQ K_A;
|
|
480
652
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
|
481
653
|
|
|
482
|
-
|
|
483
|
-
|
|
654
|
+
if constexpr (cols_per_warp == 8) {
|
|
655
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
|
656
|
+
} else {
|
|
657
|
+
// Wide version of KQ_C is column-major
|
|
658
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
659
|
+
// AMD matrix C is column-major.
|
|
660
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
|
661
|
+
#else
|
|
662
|
+
// swap A and B for CUDA.
|
|
663
|
+
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
|
664
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
665
|
+
}
|
|
484
666
|
}
|
|
485
667
|
}
|
|
486
668
|
}
|
|
@@ -532,7 +714,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
532
714
|
#pragma unroll
|
|
533
715
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
534
716
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
535
|
-
|
|
717
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
718
|
+
constexpr int KQ_idx = 0;
|
|
719
|
+
#else
|
|
720
|
+
// Turing + Volta:
|
|
721
|
+
const int KQ_idx = l % 2;
|
|
722
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
723
|
+
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
536
724
|
}
|
|
537
725
|
}
|
|
538
726
|
}
|
|
@@ -542,7 +730,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
542
730
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
543
731
|
#pragma unroll
|
|
544
732
|
for (int offset = 16; offset >= 4; offset >>= 1) {
|
|
545
|
-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
|
|
733
|
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
|
546
734
|
}
|
|
547
735
|
}
|
|
548
736
|
|
|
@@ -552,8 +740,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
552
740
|
#pragma unroll
|
|
553
741
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
554
742
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
|
555
|
-
|
|
556
|
-
|
|
743
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
744
|
+
constexpr int KQ_idx = 0;
|
|
745
|
+
#else
|
|
746
|
+
// Turing + Volta:
|
|
747
|
+
const int KQ_idx = l % 2;
|
|
748
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
749
|
+
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
|
|
750
|
+
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
|
557
751
|
} else {
|
|
558
752
|
KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
|
|
559
753
|
}
|
|
@@ -564,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
564
758
|
#pragma unroll
|
|
565
759
|
for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
|
|
566
760
|
const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
|
|
761
|
+
|
|
762
|
+
// The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed.
|
|
763
|
+
// However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout.
|
|
764
|
+
#ifdef RDNA3
|
|
765
|
+
#pragma unroll
|
|
766
|
+
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
767
|
+
const int i = i0 + T_C_KQ::get_j(l);
|
|
768
|
+
const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2;
|
|
769
|
+
|
|
770
|
+
KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
|
|
771
|
+
}
|
|
772
|
+
#else
|
|
567
773
|
#pragma unroll
|
|
568
774
|
for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
|
|
569
775
|
const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
|
|
@@ -573,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
573
779
|
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
|
|
574
780
|
KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
|
|
575
781
|
}
|
|
782
|
+
#endif // RDNA3
|
|
576
783
|
}
|
|
577
784
|
}
|
|
578
785
|
|
|
@@ -584,8 +791,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
584
791
|
#pragma unroll
|
|
585
792
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
586
793
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
794
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
795
|
+
constexpr int KQ_idx = 0;
|
|
796
|
+
#else
|
|
587
797
|
// Turing + Volta:
|
|
588
|
-
|
|
798
|
+
const int KQ_idx = (l/2) % 2;
|
|
799
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
800
|
+
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
|
589
801
|
}
|
|
590
802
|
}
|
|
591
803
|
}
|
|
@@ -596,14 +808,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
596
808
|
// Values per KQ column are spread across 4 threads:
|
|
597
809
|
constexpr int offset_first = 2;
|
|
598
810
|
constexpr int offset_last = 1;
|
|
599
|
-
#
|
|
811
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
812
|
+
// MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
|
|
813
|
+
constexpr int offset_first = 32;
|
|
814
|
+
constexpr int offset_last = 16;
|
|
815
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
816
|
+
// Values per KQ column are spread across 2 threads:
|
|
817
|
+
constexpr int offset_first = 16;
|
|
818
|
+
constexpr int offset_last = 16;
|
|
819
|
+
#else // Volta
|
|
600
820
|
// Values per KQ column are spread across 2 threads:
|
|
601
821
|
constexpr int offset_first = 2;
|
|
602
822
|
constexpr int offset_last = 2;
|
|
603
823
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
604
824
|
#pragma unroll
|
|
605
825
|
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
|
606
|
-
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset,
|
|
826
|
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
|
607
827
|
}
|
|
608
828
|
}
|
|
609
829
|
|
|
@@ -612,10 +832,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
612
832
|
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
|
|
613
833
|
#pragma unroll
|
|
614
834
|
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
|
615
|
-
// Turing + Volta:
|
|
616
835
|
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
|
617
|
-
|
|
618
|
-
|
|
836
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
837
|
+
constexpr int KQ_idx = 0;
|
|
838
|
+
#else
|
|
839
|
+
// Turing + Volta:
|
|
840
|
+
const int KQ_idx = (l/2) % 2;
|
|
841
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
842
|
+
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
|
|
843
|
+
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
|
619
844
|
} else {
|
|
620
845
|
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
|
|
621
846
|
}
|
|
@@ -639,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
639
864
|
|
|
640
865
|
#if defined(TURING_MMA_AVAILABLE)
|
|
641
866
|
if constexpr (cols_per_warp == 8) {
|
|
642
|
-
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
867
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
|
|
643
868
|
#pragma unroll
|
|
644
869
|
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
645
870
|
#pragma unroll
|
|
@@ -660,6 +885,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
660
885
|
}
|
|
661
886
|
}
|
|
662
887
|
}
|
|
888
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
889
|
+
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
|
890
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
|
891
|
+
#pragma unroll
|
|
892
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
893
|
+
#pragma unroll
|
|
894
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
895
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
896
|
+
}
|
|
897
|
+
}
|
|
898
|
+
} else {
|
|
899
|
+
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
|
900
|
+
#pragma unroll
|
|
901
|
+
for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
|
|
902
|
+
#pragma unroll
|
|
903
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
904
|
+
VKQ_C[i].x[l] *= KQ_max_scale[0];
|
|
905
|
+
}
|
|
906
|
+
}
|
|
907
|
+
}
|
|
663
908
|
#else // Volta
|
|
664
909
|
const half2 KQ_max_scale_h2 = make_half2(
|
|
665
910
|
KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
|
|
@@ -688,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
688
933
|
}
|
|
689
934
|
|
|
690
935
|
if constexpr (nstages > 1) {
|
|
936
|
+
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
|
691
937
|
// Preload K tile for next iteration:
|
|
692
938
|
constexpr bool use_cp_async = true;
|
|
693
939
|
cp_async_wait_all();
|
|
@@ -703,19 +949,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
703
949
|
}
|
|
704
950
|
|
|
705
951
|
|
|
706
|
-
// For MLA K and V have the same data.
|
|
707
|
-
// Therefore, iterate over V in reverse and re-use the data if possible.
|
|
708
|
-
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
|
709
|
-
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
|
710
|
-
|
|
711
952
|
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
|
712
953
|
#pragma unroll
|
|
713
|
-
for (int
|
|
714
|
-
|
|
715
|
-
const int
|
|
954
|
+
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
|
955
|
+
static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
|
|
956
|
+
const int i0_stop = i0_start + 2*nbatch_V2;
|
|
716
957
|
|
|
717
958
|
if constexpr (nstages <= 1) {
|
|
718
|
-
|
|
959
|
+
const int i0_diff = i0_stop - i0_start;
|
|
960
|
+
if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
|
|
719
961
|
constexpr bool use_cp_async = nstages == 1;
|
|
720
962
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
|
721
963
|
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
|
@@ -725,12 +967,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
725
967
|
__syncthreads();
|
|
726
968
|
}
|
|
727
969
|
}
|
|
728
|
-
const half2 * tile_V_i =
|
|
970
|
+
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
|
729
971
|
|
|
730
|
-
#if defined(TURING_MMA_AVAILABLE)
|
|
731
|
-
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
|
972
|
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
732
973
|
#pragma unroll
|
|
733
|
-
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 +=
|
|
974
|
+
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) {
|
|
734
975
|
static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
|
|
735
976
|
#pragma unroll
|
|
736
977
|
for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
|
|
@@ -739,10 +980,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
739
980
|
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
|
740
981
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
741
982
|
if constexpr (T_B_KQ::I == 8) {
|
|
742
|
-
mma(VKQ_C[i_VKQ_0/
|
|
983
|
+
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
|
|
743
984
|
} else {
|
|
744
|
-
// Wide version of VKQ_C is column-major
|
|
745
|
-
|
|
985
|
+
// Wide version of VKQ_C is column-major.
|
|
986
|
+
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
987
|
+
// AMD matrix C is column-major.
|
|
988
|
+
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
|
|
989
|
+
#else
|
|
990
|
+
// swap A and B for CUDA.
|
|
991
|
+
mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A);
|
|
992
|
+
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
746
993
|
}
|
|
747
994
|
}
|
|
748
995
|
}
|
|
@@ -761,7 +1008,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
761
1008
|
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
|
762
1009
|
}
|
|
763
1010
|
}
|
|
764
|
-
#endif // defined(TURING_MMA_AVAILABLE)
|
|
1011
|
+
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
765
1012
|
|
|
766
1013
|
if constexpr (nstages <= 1) {
|
|
767
1014
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
|
@@ -774,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
774
1021
|
tile_Q, tile_K, tile_V, tile_mask,
|
|
775
1022
|
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
776
1023
|
NO_DEVICE_CODE;
|
|
777
|
-
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1024
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
778
1025
|
}
|
|
779
1026
|
|
|
780
1027
|
#if defined(TURING_MMA_AVAILABLE)
|
|
781
|
-
template<int ncols> struct mma_tile_sizes {
|
|
1028
|
+
template<int DV, int ncols> struct mma_tile_sizes {
|
|
782
1029
|
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
783
1030
|
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
784
1031
|
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
@@ -786,7 +1033,7 @@ template<int ncols> struct mma_tile_sizes {
|
|
|
786
1033
|
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
787
1034
|
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
788
1035
|
};
|
|
789
|
-
template
|
|
1036
|
+
template<int DV> struct mma_tile_sizes<DV, 8> {
|
|
790
1037
|
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
791
1038
|
using T_B_KQ = tile< 8, 8, half2>; // column-major
|
|
792
1039
|
using T_C_KQ = tile<16, 8, float>; // row-major
|
|
@@ -794,8 +1041,69 @@ template<> struct mma_tile_sizes<8> {
|
|
|
794
1041
|
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
|
795
1042
|
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
|
796
1043
|
};
|
|
1044
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1045
|
+
#ifdef RDNA3
|
|
1046
|
+
template<int DV, int ncols> struct mma_tile_sizes {
|
|
1047
|
+
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1048
|
+
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1049
|
+
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1050
|
+
using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1051
|
+
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1052
|
+
using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1053
|
+
};
|
|
1054
|
+
template<int ncols> struct mma_tile_sizes<80, ncols> {
|
|
1055
|
+
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1056
|
+
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1057
|
+
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1058
|
+
using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1059
|
+
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1060
|
+
using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1061
|
+
};
|
|
1062
|
+
template<int ncols> struct mma_tile_sizes<112, ncols> {
|
|
1063
|
+
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1064
|
+
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1065
|
+
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1066
|
+
using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
1067
|
+
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
|
|
1068
|
+
using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1069
|
+
};
|
|
1070
|
+
#else
|
|
1071
|
+
template<int DV, int ncols> struct mma_tile_sizes {
|
|
1072
|
+
using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
|
|
1073
|
+
using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1074
|
+
using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1075
|
+
using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
|
|
1076
|
+
using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
1077
|
+
using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major
|
|
1078
|
+
};
|
|
1079
|
+
template<int ncols> struct mma_tile_sizes<80, ncols> {
|
|
1080
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
1081
|
+
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
1082
|
+
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
1083
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
1084
|
+
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
1085
|
+
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
1086
|
+
};
|
|
1087
|
+
template<int ncols> struct mma_tile_sizes<112, ncols> {
|
|
1088
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
1089
|
+
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
1090
|
+
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
1091
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
1092
|
+
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
1093
|
+
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
1094
|
+
};
|
|
1095
|
+
#endif // RDNA3
|
|
1096
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
1097
|
+
template<int DV, int ncols> struct mma_tile_sizes {
|
|
1098
|
+
using T_A_KQ = tile<16, 8, half2>; // row-major
|
|
1099
|
+
using T_B_KQ = tile<16, 8, half2>; // column-major
|
|
1100
|
+
using T_C_KQ = tile<16, 16, float>; // column-major
|
|
1101
|
+
using T_A_VKQ = tile<16, 8, half2>; // row-major
|
|
1102
|
+
using T_B_VKQ = tile<16, 8, half2>; // column-major
|
|
1103
|
+
using T_C_VKQ = tile<16, 8, half2>; // column-major
|
|
1104
|
+
};
|
|
797
1105
|
#else // Volta
|
|
798
|
-
template<int ncols> struct mma_tile_sizes {
|
|
1106
|
+
template<int DV, int ncols> struct mma_tile_sizes {
|
|
799
1107
|
using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
|
|
800
1108
|
using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
801
1109
|
using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
|
|
@@ -805,7 +1113,7 @@ template<int ncols> struct mma_tile_sizes {
|
|
|
805
1113
|
};
|
|
806
1114
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
807
1115
|
|
|
808
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool
|
|
1116
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
|
|
809
1117
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
810
1118
|
const float2 * const __restrict__ Q_f2,
|
|
811
1119
|
const half2 * const __restrict__ K_h2,
|
|
@@ -819,6 +1127,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
819
1127
|
const float logit_softcap,
|
|
820
1128
|
const uint3 ne01,
|
|
821
1129
|
const int ne02,
|
|
1130
|
+
const int gqa_ratio,
|
|
822
1131
|
const int ne11,
|
|
823
1132
|
const int stride_Q1,
|
|
824
1133
|
const int stride_Q2,
|
|
@@ -826,22 +1135,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
826
1135
|
const int stride_V,
|
|
827
1136
|
const int stride_mask,
|
|
828
1137
|
const int jt,
|
|
1138
|
+
const int zt_gqa,
|
|
829
1139
|
const int kb0_start,
|
|
830
1140
|
const int kb0_stop) {
|
|
831
|
-
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1141
|
+
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
832
1142
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
|
833
1143
|
|
|
1144
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
834
1145
|
constexpr int ncols = ncols1 * ncols2;
|
|
835
|
-
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
|
836
|
-
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
|
837
|
-
using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
|
|
838
|
-
using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
|
|
839
|
-
using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
|
|
840
|
-
using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
|
|
1146
|
+
using T_A_KQ = typename mma_tile_sizes<DV, ncols>::T_A_KQ;
|
|
1147
|
+
using T_B_KQ = typename mma_tile_sizes<DV, ncols>::T_B_KQ;
|
|
1148
|
+
using T_C_KQ = typename mma_tile_sizes<DV, ncols>::T_C_KQ;
|
|
1149
|
+
using T_A_VKQ = typename mma_tile_sizes<DV, ncols>::T_A_VKQ;
|
|
1150
|
+
using T_B_VKQ = typename mma_tile_sizes<DV, ncols>::T_B_VKQ;
|
|
1151
|
+
using T_C_VKQ = typename mma_tile_sizes<DV, ncols>::T_C_VKQ;
|
|
841
1152
|
|
|
842
1153
|
constexpr int cols_per_warp = T_B_KQ::I;
|
|
843
|
-
constexpr int cols_per_thread =
|
|
844
|
-
constexpr int np = nwarps *
|
|
1154
|
+
constexpr int cols_per_thread = get_cols_per_thread();
|
|
1155
|
+
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
|
845
1156
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
|
846
1157
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
|
847
1158
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
|
@@ -859,8 +1170,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
859
1170
|
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
860
1171
|
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
861
1172
|
|
|
862
|
-
|
|
863
|
-
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
1173
|
+
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
|
864
1174
|
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
|
865
1175
|
|
|
866
1176
|
extern __shared__ half2 tile_Q[];
|
|
@@ -871,6 +1181,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
871
1181
|
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
|
872
1182
|
#if defined(TURING_MMA_AVAILABLE)
|
|
873
1183
|
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
|
1184
|
+
#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
|
|
1185
|
+
T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)];
|
|
1186
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1187
|
+
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
|
874
1188
|
#else // Volta
|
|
875
1189
|
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
|
876
1190
|
#endif // defined(TURING_MMA_AVAILABLE)
|
|
@@ -887,10 +1201,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
887
1201
|
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
|
888
1202
|
const half2 scale_h2 = make_half2(scale, scale);
|
|
889
1203
|
#pragma unroll
|
|
890
|
-
for (int stride_k : {
|
|
891
|
-
const int k0_start = stride_k ==
|
|
1204
|
+
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
|
1205
|
+
const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
|
892
1206
|
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
|
893
|
-
const int stride_jc =
|
|
1207
|
+
const int stride_jc = warp_size / stride_k;
|
|
894
1208
|
|
|
895
1209
|
if (k0_start == k0_stop) {
|
|
896
1210
|
continue;
|
|
@@ -898,7 +1212,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
898
1212
|
|
|
899
1213
|
#pragma unroll
|
|
900
1214
|
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
|
901
|
-
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k ==
|
|
1215
|
+
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
902
1216
|
|
|
903
1217
|
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
|
904
1218
|
break;
|
|
@@ -907,10 +1221,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
907
1221
|
const int j = jc / ncols2;
|
|
908
1222
|
const int c = jc % ncols2;
|
|
909
1223
|
|
|
910
|
-
if (jt*ncols1 + j < int(ne01.z)) {
|
|
1224
|
+
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
|
911
1225
|
#pragma unroll
|
|
912
1226
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
913
|
-
const int k = k0 + (stride_k ==
|
|
1227
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
914
1228
|
|
|
915
1229
|
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
|
916
1230
|
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
|
@@ -918,7 +1232,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
918
1232
|
} else {
|
|
919
1233
|
#pragma unroll
|
|
920
1234
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
921
|
-
const int k = k0 + (stride_k ==
|
|
1235
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
922
1236
|
|
|
923
1237
|
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
|
924
1238
|
}
|
|
@@ -962,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
962
1276
|
constexpr bool last_iter = false;
|
|
963
1277
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
964
1278
|
flash_attn_ext_f16_iter
|
|
965
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1279
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
966
1280
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
967
1281
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
968
1282
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -971,7 +1285,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
971
1285
|
constexpr bool last_iter = true;
|
|
972
1286
|
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
|
973
1287
|
flash_attn_ext_f16_iter
|
|
974
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1288
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
975
1289
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
976
1290
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
977
1291
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -982,7 +1296,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
982
1296
|
constexpr bool last_iter = false;
|
|
983
1297
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
984
1298
|
flash_attn_ext_f16_iter
|
|
985
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1299
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
986
1300
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
987
1301
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
988
1302
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -991,7 +1305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
991
1305
|
constexpr bool last_iter = true;
|
|
992
1306
|
constexpr int k_VKQ_sup = nbatch_fa;
|
|
993
1307
|
flash_attn_ext_f16_iter
|
|
994
|
-
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1308
|
+
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
|
995
1309
|
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
|
996
1310
|
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
997
1311
|
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
|
@@ -1010,6 +1324,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1010
1324
|
// The partial sums are spread across 8/4 threads.
|
|
1011
1325
|
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
|
1012
1326
|
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
|
1327
|
+
#elif defined(AMD_MFMA_AVAILABLE)
|
|
1328
|
+
// The partial sums are spread across 4 threads (wavefront64, 16 cols).
|
|
1329
|
+
constexpr int offset_first = 32;
|
|
1330
|
+
constexpr int offset_last = 16;
|
|
1331
|
+
#elif defined(AMD_WMMA_AVAILABLE)
|
|
1332
|
+
// The partial sums are spread across 2 threads.
|
|
1333
|
+
constexpr int offset_first = 16;
|
|
1334
|
+
constexpr int offset_last = 16;
|
|
1013
1335
|
#else // Volta
|
|
1014
1336
|
// The partial sums are spread across 2 threads.
|
|
1015
1337
|
constexpr int offset_first = 2;
|
|
@@ -1019,19 +1341,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1019
1341
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
1020
1342
|
#pragma unroll
|
|
1021
1343
|
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
|
1022
|
-
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset,
|
|
1344
|
+
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
|
|
1023
1345
|
}
|
|
1024
1346
|
}
|
|
1025
1347
|
}
|
|
1026
1348
|
|
|
1027
1349
|
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
|
1028
|
-
// Also add the sink as a value to KQ_rowsum, this is done after
|
|
1350
|
+
// Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
|
|
1029
1351
|
// so it's being done unconditionally for every thread.
|
|
1030
1352
|
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
|
1031
1353
|
float KQ_max_scale[cols_per_thread];
|
|
1032
1354
|
#pragma unroll
|
|
1033
1355
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
1034
|
-
const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
|
|
1356
|
+
const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
|
|
1035
1357
|
const float sink = sinks_f[jc % ncols2];
|
|
1036
1358
|
|
|
1037
1359
|
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
|
@@ -1047,7 +1369,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1047
1369
|
|
|
1048
1370
|
#if defined(TURING_MMA_AVAILABLE)
|
|
1049
1371
|
if constexpr (cols_per_warp == 8) {
|
|
1050
|
-
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
|
1372
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
|
|
1051
1373
|
#pragma unroll
|
|
1052
1374
|
for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
|
|
1053
1375
|
#pragma unroll
|
|
@@ -1068,6 +1390,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1068
1390
|
}
|
|
1069
1391
|
}
|
|
1070
1392
|
}
|
|
1393
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1394
|
+
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
|
1395
|
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
|
1396
|
+
#pragma unroll
|
|
1397
|
+
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
|
1398
|
+
#pragma unroll
|
|
1399
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1400
|
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
|
1401
|
+
}
|
|
1402
|
+
}
|
|
1403
|
+
} else {
|
|
1404
|
+
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
|
1405
|
+
#pragma unroll
|
|
1406
|
+
for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
|
|
1407
|
+
#pragma unroll
|
|
1408
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1409
|
+
VKQ_C[i].x[l] *= KQ_max_scale[0];
|
|
1410
|
+
}
|
|
1411
|
+
}
|
|
1412
|
+
}
|
|
1071
1413
|
#else // Volta
|
|
1072
1414
|
const int col = (threadIdx.x / 2) % 2;
|
|
1073
1415
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
|
@@ -1119,6 +1461,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1119
1461
|
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
|
1120
1462
|
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
|
1121
1463
|
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
|
1464
|
+
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1465
|
+
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
|
|
1466
|
+
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
|
|
1467
|
+
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
|
|
1122
1468
|
#else // Volta
|
|
1123
1469
|
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
|
|
1124
1470
|
const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
|
|
@@ -1149,14 +1495,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1149
1495
|
// Warps with threadIdx.y % np != 0 must NOT return early.
|
|
1150
1496
|
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
|
1151
1497
|
|
|
1152
|
-
constexpr int nmeta = np*cols_per_warp >=
|
|
1498
|
+
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
|
|
1153
1499
|
|
|
1154
|
-
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp <
|
|
1500
|
+
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
|
1155
1501
|
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
|
1156
1502
|
float2 meta[nmeta];
|
|
1157
1503
|
#pragma unroll
|
|
1158
1504
|
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
|
1159
|
-
meta[imeta] = meta_ptr[imeta *
|
|
1505
|
+
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
|
|
1160
1506
|
}
|
|
1161
1507
|
|
|
1162
1508
|
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
|
@@ -1166,8 +1512,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1166
1512
|
}
|
|
1167
1513
|
#pragma unroll
|
|
1168
1514
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
|
1169
|
-
if (offset <
|
|
1170
|
-
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset,
|
|
1515
|
+
if (offset < warp_size) {
|
|
1516
|
+
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
|
|
1171
1517
|
}
|
|
1172
1518
|
}
|
|
1173
1519
|
|
|
@@ -1184,8 +1530,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1184
1530
|
}
|
|
1185
1531
|
#pragma unroll
|
|
1186
1532
|
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
|
1187
|
-
if (offset <
|
|
1188
|
-
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset,
|
|
1533
|
+
if (offset < warp_size) {
|
|
1534
|
+
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
|
|
1189
1535
|
}
|
|
1190
1536
|
}
|
|
1191
1537
|
|
|
@@ -1194,19 +1540,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1194
1540
|
// Write back combined meta data:
|
|
1195
1541
|
#pragma unroll
|
|
1196
1542
|
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
|
1197
|
-
if (np*cols_per_warp >=
|
|
1543
|
+
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
|
|
1198
1544
|
// Combined KQ max scale + rowsum.
|
|
1199
|
-
meta_ptr[imeta *
|
|
1545
|
+
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
|
1200
1546
|
}
|
|
1201
1547
|
}
|
|
1202
1548
|
|
|
1203
1549
|
// Combined KQ max + rowsum.
|
|
1204
|
-
static_assert(cols_per_warp <=
|
|
1205
|
-
if (needs_fixup && (cols_per_warp ==
|
|
1550
|
+
static_assert(cols_per_warp <= warp_size);
|
|
1551
|
+
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
|
1206
1552
|
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
|
1207
1553
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
|
1208
1554
|
}
|
|
1209
|
-
if (is_fixup && (cols_per_warp ==
|
|
1555
|
+
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
|
1210
1556
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
|
1211
1557
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
|
1212
1558
|
}
|
|
@@ -1220,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1220
1566
|
#pragma unroll
|
|
1221
1567
|
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
|
1222
1568
|
if constexpr (cols_per_warp == 8) {
|
|
1569
|
+
static_assert(std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>, "bad VKQ type");
|
|
1223
1570
|
const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
|
|
1224
1571
|
#pragma unroll
|
|
1225
1572
|
for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
|
|
@@ -1234,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1234
1581
|
}
|
|
1235
1582
|
} else {
|
|
1236
1583
|
const int j0 = threadIdx.y*cols_per_warp;
|
|
1584
|
+
if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
|
|
1585
|
+
if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) {
|
|
1237
1586
|
#pragma unroll
|
|
1238
|
-
|
|
1587
|
+
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
|
1239
1588
|
#pragma unroll
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1589
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1590
|
+
const int j = j0 + T_C_VKQ::get_i(l);
|
|
1591
|
+
const int k = k1 + T_C_VKQ::get_j(l);
|
|
1592
|
+
|
|
1593
|
+
tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
|
|
1594
|
+
}
|
|
1595
|
+
}
|
|
1596
|
+
} else {
|
|
1597
|
+
static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout");
|
|
1598
|
+
using T_C_VKQ_us = tile<T_C_VKQ::I, T_C_VKQ::J, half2, DATA_LAYOUT_I_MAJOR>; // us == unscrambled
|
|
1599
|
+
#pragma unroll
|
|
1600
|
+
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
|
|
1601
|
+
const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]);
|
|
1602
|
+
#pragma unroll
|
|
1603
|
+
for (int l = 0; l < T_C_VKQ_us::ne; ++l) {
|
|
1604
|
+
const int j = j0 + T_C_VKQ_us::get_i(l);
|
|
1605
|
+
const int k = k1 + T_C_VKQ_us::get_j(l);
|
|
1243
1606
|
|
|
1244
|
-
|
|
1607
|
+
tile_Q[j*tile_stride + k] = VKQ_C_us.x[l];
|
|
1608
|
+
}
|
|
1609
|
+
}
|
|
1610
|
+
}
|
|
1611
|
+
} else {
|
|
1612
|
+
static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
|
|
1613
|
+
half * tile_Q_h = (half *) tile_Q;
|
|
1614
|
+
#pragma unroll
|
|
1615
|
+
for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) {
|
|
1616
|
+
#pragma unroll
|
|
1617
|
+
for (int l = 0; l < T_C_VKQ::ne; ++l) {
|
|
1618
|
+
const int j = j0 + T_C_VKQ::get_i(l);
|
|
1619
|
+
const int k = 2*k1 + T_C_VKQ::get_j(l);
|
|
1620
|
+
|
|
1621
|
+
tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l];
|
|
1622
|
+
}
|
|
1245
1623
|
}
|
|
1246
1624
|
}
|
|
1247
1625
|
}
|
|
@@ -1254,10 +1632,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1254
1632
|
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
|
1255
1633
|
|
|
1256
1634
|
#pragma unroll
|
|
1257
|
-
for (int stride_k : {
|
|
1258
|
-
const int k0_start = stride_k ==
|
|
1635
|
+
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
|
1636
|
+
const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
|
1259
1637
|
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
|
1260
|
-
const int stride_jc =
|
|
1638
|
+
const int stride_jc = warp_size / stride_k;
|
|
1261
1639
|
|
|
1262
1640
|
if (k0_start == k0_stop) {
|
|
1263
1641
|
continue;
|
|
@@ -1265,7 +1643,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1265
1643
|
|
|
1266
1644
|
#pragma unroll
|
|
1267
1645
|
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
|
1268
|
-
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k ==
|
|
1646
|
+
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
|
1269
1647
|
|
|
1270
1648
|
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
|
1271
1649
|
break;
|
|
@@ -1276,14 +1654,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1276
1654
|
const int j_dst = jc_dst / ncols2;
|
|
1277
1655
|
const int c_dst = jc_dst % ncols2;
|
|
1278
1656
|
|
|
1279
|
-
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
|
|
1657
|
+
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
|
|
1280
1658
|
continue;
|
|
1281
1659
|
}
|
|
1282
1660
|
|
|
1283
1661
|
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
|
1284
1662
|
#pragma unroll
|
|
1285
1663
|
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
|
1286
|
-
const int k = k0 + (stride_k ==
|
|
1664
|
+
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
|
1287
1665
|
|
|
1288
1666
|
float2 dstk_val = make_float2(0.0f, 0.0f);
|
|
1289
1667
|
#pragma unroll
|
|
@@ -1315,24 +1693,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1315
1693
|
}
|
|
1316
1694
|
#else
|
|
1317
1695
|
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
|
|
1318
|
-
scale, slope, logit_softcap, ne01, ne02,
|
|
1696
|
+
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
|
|
1319
1697
|
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
|
1320
1698
|
jt, kb0_start, kb0_stop);
|
|
1321
1699
|
NO_DEVICE_CODE;
|
|
1322
|
-
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
|
1700
|
+
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
|
1323
1701
|
}
|
|
1324
1702
|
|
|
1325
|
-
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool
|
|
1703
|
+
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
|
1326
1704
|
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
|
1327
1705
|
static __global__ void flash_attn_ext_f16(
|
|
1328
|
-
const char *
|
|
1329
|
-
const char *
|
|
1330
|
-
const char *
|
|
1331
|
-
const char *
|
|
1332
|
-
const char *
|
|
1333
|
-
const int *
|
|
1334
|
-
float *
|
|
1335
|
-
float2 *
|
|
1706
|
+
const char * Q_ptr,
|
|
1707
|
+
const char * K_ptr,
|
|
1708
|
+
const char * V_ptr,
|
|
1709
|
+
const char * mask_ptr,
|
|
1710
|
+
const char * sinks_ptr,
|
|
1711
|
+
const int * KV_max_ptr,
|
|
1712
|
+
float * dst_ptr,
|
|
1713
|
+
float2 * dst_meta_ptr,
|
|
1336
1714
|
const float scale,
|
|
1337
1715
|
const float max_bias,
|
|
1338
1716
|
const float m0,
|
|
@@ -1346,13 +1724,33 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1346
1724
|
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
1347
1725
|
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
1348
1726
|
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
1349
|
-
|
|
1727
|
+
ggml_cuda_pdl_sync(); // TODO optimize placement
|
|
1728
|
+
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
|
1729
|
+
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
|
|
1730
|
+
const char * GGML_CUDA_RESTRICT K = K_ptr;
|
|
1731
|
+
const char * GGML_CUDA_RESTRICT V = V_ptr;
|
|
1732
|
+
const char * GGML_CUDA_RESTRICT mask = mask_ptr;
|
|
1733
|
+
const char * GGML_CUDA_RESTRICT sinks = sinks_ptr;
|
|
1734
|
+
const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
|
|
1735
|
+
float * GGML_CUDA_RESTRICT dst = dst_ptr;
|
|
1736
|
+
float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr;
|
|
1350
1737
|
|
|
1351
1738
|
// Skip unused kernel variants for faster compilation:
|
|
1352
|
-
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
|
1739
|
+
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
|
|
1353
1740
|
NO_DEVICE_CODE;
|
|
1354
1741
|
return;
|
|
1355
1742
|
}
|
|
1743
|
+
if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
|
|
1744
|
+
NO_DEVICE_CODE;
|
|
1745
|
+
return;
|
|
1746
|
+
}
|
|
1747
|
+
#ifdef VOLTA_MMA_AVAILABLE
|
|
1748
|
+
if (ncols1*ncols2 < 32) {
|
|
1749
|
+
NO_DEVICE_CODE;
|
|
1750
|
+
return;
|
|
1751
|
+
}
|
|
1752
|
+
#endif // VOLTA_MMA_AVAILABLE
|
|
1753
|
+
|
|
1356
1754
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1357
1755
|
if (ncols1*ncols2 > 32) {
|
|
1358
1756
|
NO_DEVICE_CODE;
|
|
@@ -1360,12 +1758,25 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1360
1758
|
}
|
|
1361
1759
|
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1362
1760
|
|
|
1363
|
-
|
|
1761
|
+
#if defined(AMD_WMMA_AVAILABLE)
|
|
1762
|
+
if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) {
|
|
1763
|
+
NO_DEVICE_CODE;
|
|
1764
|
+
return;
|
|
1765
|
+
}
|
|
1766
|
+
#endif // defined(AMD_WMMA_AVAILABLE)
|
|
1767
|
+
|
|
1768
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1769
|
+
if (ncols1*ncols2 < 16 || DKQ > 256) {
|
|
1770
|
+
NO_DEVICE_CODE;
|
|
1771
|
+
return;
|
|
1772
|
+
}
|
|
1773
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
1364
1774
|
|
|
1775
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1365
1776
|
constexpr int ncols = ncols1 * ncols2;
|
|
1366
1777
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
|
1367
1778
|
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
|
1368
|
-
constexpr int nwarps = nthreads /
|
|
1779
|
+
constexpr int nwarps = nthreads / warp_size;
|
|
1369
1780
|
|
|
1370
1781
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
1371
1782
|
|
|
@@ -1374,14 +1785,15 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1374
1785
|
const int stride_K = nb11 / sizeof(half2);
|
|
1375
1786
|
const int stride_mask = nb31 / sizeof(half);
|
|
1376
1787
|
|
|
1377
|
-
const int stride_V =
|
|
1788
|
+
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
|
|
1378
1789
|
|
|
1379
|
-
const int iter_k
|
|
1380
|
-
const int iter_j
|
|
1790
|
+
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
|
1791
|
+
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
|
1792
|
+
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
|
1381
1793
|
|
|
1382
1794
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
1383
|
-
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*
|
|
1384
|
-
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*
|
|
1795
|
+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
1796
|
+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
|
1385
1797
|
|
|
1386
1798
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
1387
1799
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1392,22 +1804,24 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1392
1804
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
|
1393
1805
|
|
|
1394
1806
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
|
1395
|
-
|
|
1396
|
-
const int
|
|
1397
|
-
const int
|
|
1807
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
|
1808
|
+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
1809
|
+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
1810
|
+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
1811
|
+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
1398
1812
|
|
|
1399
|
-
const int
|
|
1813
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
1400
1814
|
|
|
1401
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*
|
|
1402
|
-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*
|
|
1815
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
|
1816
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
|
1403
1817
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1404
1818
|
(const half *) (mask + nb33*(sequence % ne33));
|
|
1405
|
-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 +
|
|
1819
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
|
1406
1820
|
|
|
1407
|
-
const half2 * V_h2 =
|
|
1408
|
-
const float * sinks_f = sinks ? (const float *) sinks +
|
|
1821
|
+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
|
1822
|
+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
|
1409
1823
|
|
|
1410
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1824
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
|
1411
1825
|
|
|
1412
1826
|
if (KV_max) {
|
|
1413
1827
|
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
@@ -1415,14 +1829,14 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1415
1829
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
1416
1830
|
if (kb0_start == 0) {
|
|
1417
1831
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
|
1418
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1832
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1419
1833
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1420
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1834
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1421
1835
|
} else {
|
|
1422
1836
|
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
|
1423
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1837
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1424
1838
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1425
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1839
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1426
1840
|
}
|
|
1427
1841
|
|
|
1428
1842
|
kbc += iter_k;
|
|
@@ -1436,22 +1850,24 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1436
1850
|
return;
|
|
1437
1851
|
}
|
|
1438
1852
|
|
|
1439
|
-
|
|
1440
|
-
const int
|
|
1441
|
-
const int
|
|
1853
|
+
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
|
|
1854
|
+
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
|
|
1855
|
+
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
|
1856
|
+
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
|
1857
|
+
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
|
1442
1858
|
|
|
1443
|
-
const int
|
|
1859
|
+
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
|
1444
1860
|
|
|
1445
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*
|
|
1446
|
-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*
|
|
1861
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
|
|
1862
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
|
|
1447
1863
|
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
|
|
1448
1864
|
(const half *) (mask + nb33*(sequence % ne33));
|
|
1449
|
-
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 +
|
|
1865
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
|
|
1450
1866
|
|
|
1451
|
-
const half2 * V_h2 =
|
|
1452
|
-
const float * sinks_f = sinks ? (const float *) sinks +
|
|
1867
|
+
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
|
|
1868
|
+
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
|
|
1453
1869
|
|
|
1454
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1870
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
|
|
1455
1871
|
|
|
1456
1872
|
if (KV_max) {
|
|
1457
1873
|
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
|
|
@@ -1459,11 +1875,11 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1459
1875
|
|
|
1460
1876
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
1461
1877
|
constexpr bool needs_fixup = false;
|
|
1462
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap,
|
|
1878
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
|
1463
1879
|
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1464
|
-
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
|
1880
|
+
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
|
|
1465
1881
|
#else
|
|
1466
|
-
GGML_UNUSED_VARS(
|
|
1882
|
+
GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale,
|
|
1467
1883
|
max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
1468
1884
|
ne00, ne01, ne02, ne03,
|
|
1469
1885
|
nb01, nb02, nb03,
|
|
@@ -1473,7 +1889,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1473
1889
|
ne31, ne32, ne33,
|
|
1474
1890
|
nb31, nb32, nb33);
|
|
1475
1891
|
NO_DEVICE_CODE;
|
|
1476
|
-
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
|
1892
|
+
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
|
1477
1893
|
}
|
|
1478
1894
|
|
|
1479
1895
|
template <int DKQ, int DV, int ncols1, int ncols2>
|
|
@@ -1492,10 +1908,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1492
1908
|
const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
|
|
1493
1909
|
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
|
1494
1910
|
|
|
1495
|
-
const int cols_per_warp = std::min(ncols,
|
|
1496
|
-
const int
|
|
1911
|
+
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
|
1912
|
+
const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
|
|
1913
|
+
const int nwarps = nthreads / warp_size_host;
|
|
1497
1914
|
|
|
1498
|
-
constexpr bool
|
|
1915
|
+
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
|
1499
1916
|
|
|
1500
1917
|
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
|
1501
1918
|
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
|
@@ -1512,33 +1929,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1512
1929
|
float logit_softcap;
|
|
1513
1930
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
1514
1931
|
|
|
1932
|
+
#if defined(GGML_USE_HIP)
|
|
1933
|
+
using fattn_kernel_ptr_t = const void*;
|
|
1934
|
+
#else
|
|
1935
|
+
using fattn_kernel_ptr_t = fattn_kernel_t;
|
|
1936
|
+
#endif // defined(GGML_USE_HIP)
|
|
1515
1937
|
fattn_kernel_t fattn_kernel;
|
|
1516
1938
|
if (logit_softcap == 0.0f) {
|
|
1517
1939
|
constexpr bool use_logit_softcap = false;
|
|
1518
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap,
|
|
1940
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
|
1519
1941
|
|
|
1520
|
-
#if !defined(
|
|
1942
|
+
#if !defined(GGML_USE_MUSA)
|
|
1521
1943
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1522
1944
|
if (!shared_memory_limit_raised[id]) {
|
|
1523
|
-
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1945
|
+
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1524
1946
|
shared_memory_limit_raised[id] = true;
|
|
1525
1947
|
}
|
|
1526
|
-
#endif // !defined(
|
|
1948
|
+
#endif // !defined(GGML_USE_MUSA)
|
|
1527
1949
|
} else {
|
|
1528
1950
|
constexpr bool use_logit_softcap = true;
|
|
1529
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap,
|
|
1951
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
|
1530
1952
|
|
|
1531
|
-
#if !defined(
|
|
1953
|
+
#if !defined(GGML_USE_MUSA)
|
|
1532
1954
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
1533
1955
|
if (!shared_memory_limit_raised[id]) {
|
|
1534
|
-
CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1956
|
+
CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
|
|
1535
1957
|
shared_memory_limit_raised[id] = true;
|
|
1536
1958
|
}
|
|
1537
|
-
#endif // !defined(
|
|
1959
|
+
#endif // !defined(GGML_USE_MUSA)
|
|
1538
1960
|
}
|
|
1539
1961
|
|
|
1540
1962
|
launch_fattn<DV, ncols1, ncols2>
|
|
1541
|
-
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
|
1963
|
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
|
|
1542
1964
|
}
|
|
1543
1965
|
|
|
1544
1966
|
|
|
@@ -1581,7 +2003,27 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
|
|
|
1581
2003
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
|
|
1582
2004
|
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
|
1583
2005
|
|
|
2006
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
|
|
2007
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
|
|
2008
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
|
|
2009
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
|
|
2010
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
|
|
2011
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
|
|
2012
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
|
|
2013
|
+
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
|
|
2014
|
+
|
|
1584
2015
|
// The number of viable configurations for Deepseek is very limited:
|
|
1585
2016
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
|
1586
2017
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
|
1587
2018
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
|
2019
|
+
|
|
2020
|
+
// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
|
|
2021
|
+
extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
|
|
2022
|
+
extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
|
|
2023
|
+
|
|
2024
|
+
// For GLM 4.7 Flash
|
|
2025
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
|
2026
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
|
2027
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
|
2028
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
|
|
2029
|
+
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
|