whispercpp 1.3.6 → 1.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -21,6 +21,33 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|
|
21
21
|
|
|
22
22
|
#include <vulkan/vulkan.hpp>
|
|
23
23
|
|
|
24
|
+
// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the
|
|
25
|
+
// installed Vulkan headers predate the extension.
|
|
26
|
+
#ifndef VK_NV_cooperative_matrix_decode_vector
|
|
27
|
+
#define VK_NV_cooperative_matrix_decode_vector 1
|
|
28
|
+
#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector"
|
|
29
|
+
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000)
|
|
30
|
+
typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
|
|
31
|
+
VkStructureType sType;
|
|
32
|
+
void* pNext;
|
|
33
|
+
VkBool32 cooperativeMatrixDecodeVector;
|
|
34
|
+
} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV;
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
// SPIR-V Headers: different SDK installations expose different include paths.
|
|
38
|
+
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
|
|
39
|
+
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
|
|
40
|
+
#if __has_include(<spirv/unified1/spirv.hpp>)
|
|
41
|
+
# include <spirv/unified1/spirv.hpp>
|
|
42
|
+
#elif __has_include(<spirv-headers/spirv.hpp>)
|
|
43
|
+
# include <spirv-headers/spirv.hpp>
|
|
44
|
+
#elif __has_include(<spirv.hpp>)
|
|
45
|
+
# include <spirv.hpp>
|
|
46
|
+
#else
|
|
47
|
+
// Fallback to let the compiler throw a standard "file not found" error
|
|
48
|
+
# include <spirv/unified1/spirv.hpp>
|
|
49
|
+
#endif
|
|
50
|
+
|
|
24
51
|
#include <algorithm>
|
|
25
52
|
#include <cmath>
|
|
26
53
|
#include <iomanip>
|
|
@@ -35,9 +62,10 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|
|
35
62
|
#include <map>
|
|
36
63
|
#include <set>
|
|
37
64
|
#include <unordered_map>
|
|
38
|
-
#include <
|
|
65
|
+
#include <shared_mutex>
|
|
39
66
|
#include <mutex>
|
|
40
67
|
#include <future>
|
|
68
|
+
#include <condition_variable>
|
|
41
69
|
#include <thread>
|
|
42
70
|
|
|
43
71
|
#if defined(_MSC_VER)
|
|
@@ -85,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
|
|
85
113
|
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
|
86
114
|
#endif
|
|
87
115
|
|
|
116
|
+
#if !defined(VK_VALVE_shader_mixed_float_dot_product)
|
|
117
|
+
#define VK_VALVE_shader_mixed_float_dot_product 1
|
|
118
|
+
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1
|
|
119
|
+
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product"
|
|
120
|
+
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000)
|
|
121
|
+
typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE {
|
|
122
|
+
VkStructureType sType;
|
|
123
|
+
void* pNext;
|
|
124
|
+
VkBool32 shaderMixedFloatDotProductFloat16AccFloat32;
|
|
125
|
+
VkBool32 shaderMixedFloatDotProductFloat16AccFloat16;
|
|
126
|
+
VkBool32 shaderMixedFloatDotProductBFloat16Acc;
|
|
127
|
+
VkBool32 shaderMixedFloatDotProductFloat8AccFloat32;
|
|
128
|
+
} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE;
|
|
129
|
+
#endif
|
|
130
|
+
|
|
88
131
|
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
|
89
132
|
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
|
90
133
|
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
|
@@ -97,8 +140,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
|
|
97
140
|
|
|
98
141
|
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
|
|
99
142
|
|
|
100
|
-
#define GGML_VK_MAX_NODES 8192
|
|
101
|
-
|
|
102
143
|
#define VK_CHECK(err, msg) \
|
|
103
144
|
do { \
|
|
104
145
|
vk::Result err_ = (err); \
|
|
@@ -134,8 +175,9 @@ struct vk_pipeline_struct {
|
|
|
134
175
|
uint32_t align;
|
|
135
176
|
// true if fields have been set by ggml_vk_create_pipeline
|
|
136
177
|
bool initialized {};
|
|
137
|
-
//
|
|
138
|
-
|
|
178
|
+
// true while a compile is in flight, used to dedupe concurrent claims.
|
|
179
|
+
// Protected by device->compile_mutex.
|
|
180
|
+
bool compile_pending {};
|
|
139
181
|
// set to true when the shader has been compiled
|
|
140
182
|
std::atomic<bool> compiled {};
|
|
141
183
|
// number of registers used, extracted from pipeline executable properties
|
|
@@ -191,6 +233,7 @@ struct vk_queue;
|
|
|
191
233
|
|
|
192
234
|
struct vk_command_buffer {
|
|
193
235
|
vk::CommandBuffer buf;
|
|
236
|
+
uint64_t use_counter = 0;
|
|
194
237
|
bool in_use = false;
|
|
195
238
|
};
|
|
196
239
|
|
|
@@ -386,6 +429,7 @@ enum vk_conv_shapes {
|
|
|
386
429
|
CONV_SHAPE_128x128,
|
|
387
430
|
CONV_SHAPE_64x32,
|
|
388
431
|
CONV_SHAPE_32x256,
|
|
432
|
+
CONV_SHAPE_64x128,
|
|
389
433
|
CONV_SHAPE_COUNT,
|
|
390
434
|
};
|
|
391
435
|
|
|
@@ -400,6 +444,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
|
|
|
400
444
|
{ 128, 128, 16 }, // CONV_SHAPE_128x128
|
|
401
445
|
{ 64, 32, 32 }, // CONV_SHAPE_64x32
|
|
402
446
|
{ 32, 256, 16 }, // CONV_SHAPE_32x256
|
|
447
|
+
{ 64, 128, 16 }, // CONV_SHAPE_64x128
|
|
403
448
|
};
|
|
404
449
|
|
|
405
450
|
enum dmmv_wg_sizes {
|
|
@@ -425,22 +470,26 @@ struct vk_fa_pipeline_state {
|
|
|
425
470
|
bool f32acc;
|
|
426
471
|
uint32_t flags;
|
|
427
472
|
uint32_t limit_occupancy_shmem;
|
|
473
|
+
ggml_type k_type;
|
|
474
|
+
ggml_type v_type;
|
|
428
475
|
|
|
429
476
|
bool operator<(const vk_fa_pipeline_state &b) const {
|
|
430
|
-
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
|
|
431
|
-
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
|
|
477
|
+
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
|
|
478
|
+
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
|
|
432
479
|
}
|
|
433
480
|
};
|
|
434
481
|
|
|
435
482
|
struct vk_conv2d_pipeline_state {
|
|
436
|
-
vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
|
|
437
|
-
: s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
|
|
483
|
+
vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned)
|
|
484
|
+
: s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {}
|
|
438
485
|
|
|
439
486
|
uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
|
|
487
|
+
// when set, shader can skip K/CRS/NPQ bounds checks and address clamps
|
|
488
|
+
uint32_t aligned;
|
|
440
489
|
|
|
441
490
|
bool operator<(const vk_conv2d_pipeline_state &b) const {
|
|
442
|
-
return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
|
|
443
|
-
std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
|
|
491
|
+
return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) <
|
|
492
|
+
std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned);
|
|
444
493
|
}
|
|
445
494
|
};
|
|
446
495
|
|
|
@@ -485,6 +534,12 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
|
|
|
485
534
|
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
|
486
535
|
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
|
487
536
|
|
|
537
|
+
// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder
|
|
538
|
+
// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion.
|
|
539
|
+
static constexpr std::initializer_list<ggml_op> snake_pattern { GGML_OP_MUL, GGML_OP_SIN,
|
|
540
|
+
GGML_OP_SQR, GGML_OP_MUL,
|
|
541
|
+
GGML_OP_ADD };
|
|
542
|
+
|
|
488
543
|
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
|
|
489
544
|
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
|
490
545
|
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
|
|
@@ -581,6 +636,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie
|
|
|
581
636
|
|
|
582
637
|
struct vk_device_struct {
|
|
583
638
|
std::recursive_mutex mutex;
|
|
639
|
+
mutable std::shared_mutex pinned_memory_mutex;
|
|
640
|
+
|
|
641
|
+
// Guards compile_pending, all_pipelines, and the dynamic pipeline maps
|
|
642
|
+
// (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
|
|
643
|
+
// runs with no lock held, so different pipelines can compile in parallel.
|
|
644
|
+
// Lock order is device->mutex -> compile_mutex, never the reverse.
|
|
645
|
+
std::mutex compile_mutex;
|
|
646
|
+
std::condition_variable compile_cv;
|
|
584
647
|
|
|
585
648
|
vk::PhysicalDevice physical_device;
|
|
586
649
|
vk::PhysicalDeviceProperties properties;
|
|
@@ -654,6 +717,10 @@ struct vk_device_struct {
|
|
|
654
717
|
uint32_t coopmat_int_k;
|
|
655
718
|
|
|
656
719
|
bool coopmat2;
|
|
720
|
+
bool coopmat2_bf16_support {};
|
|
721
|
+
bool coopmat2_decode_vector;
|
|
722
|
+
|
|
723
|
+
bool dot2_f16 {};
|
|
657
724
|
|
|
658
725
|
bool pipeline_executable_properties_support {};
|
|
659
726
|
|
|
@@ -666,6 +733,15 @@ struct vk_device_struct {
|
|
|
666
733
|
bool mul_mat_id_m[GGML_TYPE_COUNT];
|
|
667
734
|
bool mul_mat_id_s[GGML_TYPE_COUNT];
|
|
668
735
|
|
|
736
|
+
// Separate flags for the q8_1 (integer dot) mmq path, whose shader uses
|
|
737
|
+
// a different shared-memory layout than the float matmul shaders.
|
|
738
|
+
bool mul_mat_l_int[GGML_TYPE_COUNT];
|
|
739
|
+
bool mul_mat_m_int[GGML_TYPE_COUNT];
|
|
740
|
+
bool mul_mat_s_int[GGML_TYPE_COUNT];
|
|
741
|
+
bool mul_mat_id_l_int[GGML_TYPE_COUNT];
|
|
742
|
+
bool mul_mat_id_m_int[GGML_TYPE_COUNT];
|
|
743
|
+
bool mul_mat_id_s_int[GGML_TYPE_COUNT];
|
|
744
|
+
|
|
669
745
|
vk::DescriptorSetLayout dsl;
|
|
670
746
|
|
|
671
747
|
vk_matmul_pipeline pipeline_matmul_f32 {};
|
|
@@ -735,9 +811,10 @@ struct vk_device_struct {
|
|
|
735
811
|
vk_pipeline pipeline_clamp_f32;
|
|
736
812
|
vk_pipeline pipeline_pad_f32;
|
|
737
813
|
vk_pipeline pipeline_roll_f32;
|
|
738
|
-
vk_pipeline
|
|
739
|
-
vk_pipeline
|
|
740
|
-
vk_pipeline
|
|
814
|
+
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
|
|
815
|
+
vk_pipeline pipeline_repeat_i16;
|
|
816
|
+
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
|
|
817
|
+
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
|
|
741
818
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
|
742
819
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
|
743
820
|
vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
|
|
@@ -784,6 +861,7 @@ struct vk_device_struct {
|
|
|
784
861
|
vk_pipeline pipeline_arange_f32;
|
|
785
862
|
|
|
786
863
|
vk_pipeline pipeline_fill_f32;
|
|
864
|
+
vk_pipeline pipeline_fill_f16;
|
|
787
865
|
|
|
788
866
|
vk_pipeline pipeline_geglu[2];
|
|
789
867
|
vk_pipeline pipeline_reglu[2];
|
|
@@ -811,6 +889,7 @@ struct vk_device_struct {
|
|
|
811
889
|
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
|
812
890
|
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
|
813
891
|
vk_pipeline pipeline_sum_rows_f32;
|
|
892
|
+
vk_pipeline pipeline_fwht_f32[4];
|
|
814
893
|
vk_pipeline pipeline_cumsum_f32;
|
|
815
894
|
vk_pipeline pipeline_cumsum_small_f32;
|
|
816
895
|
vk_pipeline pipeline_cumsum_multipass1_f32;
|
|
@@ -822,6 +901,9 @@ struct vk_device_struct {
|
|
|
822
901
|
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
|
|
823
902
|
vk_pipeline pipeline_timestep_embedding_f32;
|
|
824
903
|
vk_pipeline pipeline_conv_transpose_1d_f32;
|
|
904
|
+
vk_pipeline pipeline_snake_f32;
|
|
905
|
+
vk_pipeline pipeline_snake_f16;
|
|
906
|
+
vk_pipeline pipeline_snake_bf16;
|
|
825
907
|
vk_pipeline pipeline_pool2d_f32;
|
|
826
908
|
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
827
909
|
vk_pipeline pipeline_rwkv_wkv7_f32;
|
|
@@ -830,6 +912,8 @@ struct vk_device_struct {
|
|
|
830
912
|
vk_pipeline pipeline_ssm_scan_f32_d128;
|
|
831
913
|
vk_pipeline pipeline_ssm_scan_f32_d256;
|
|
832
914
|
vk_pipeline pipeline_ssm_conv_f32;
|
|
915
|
+
vk_pipeline pipeline_ssm_conv_silu_f32;
|
|
916
|
+
vk_pipeline pipeline_ssm_conv_bias_silu_f32;
|
|
833
917
|
vk_pipeline pipeline_opt_step_adamw_f32;
|
|
834
918
|
vk_pipeline pipeline_opt_step_sgd_f32;
|
|
835
919
|
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
|
@@ -839,7 +923,7 @@ struct vk_device_struct {
|
|
|
839
923
|
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
|
840
924
|
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
|
841
925
|
|
|
842
|
-
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16
|
|
926
|
+
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
|
|
843
927
|
|
|
844
928
|
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
|
|
845
929
|
|
|
@@ -938,19 +1022,24 @@ struct vk_subbuffer {
|
|
|
938
1022
|
}
|
|
939
1023
|
};
|
|
940
1024
|
|
|
941
|
-
|
|
942
|
-
|
|
1025
|
+
struct vk_semaphore {
|
|
1026
|
+
vk::Semaphore s;
|
|
1027
|
+
uint64_t value;
|
|
1028
|
+
};
|
|
1029
|
+
|
|
1030
|
+
// vk_event is used for the event-related backend interfaces. It uses vk::Events for
|
|
1031
|
+
// event_wait and a timeline semaphore for event_synchronize. Polling on an event for
|
|
943
1032
|
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
|
|
944
1033
|
// and would lead to validation errors.
|
|
945
1034
|
struct vk_event {
|
|
1035
|
+
std::vector<vk::Event> events_free; // Events available for reuse
|
|
1036
|
+
std::vector<vk::Event> events_submitted; // Events that are fully submitted and can be reused on next synchronize
|
|
946
1037
|
vk::Event event;
|
|
947
|
-
|
|
948
|
-
vk_command_buffer* cmd_buffer = nullptr;
|
|
949
|
-
};
|
|
1038
|
+
bool has_event;
|
|
950
1039
|
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
uint64_t
|
|
1040
|
+
vk_semaphore tl_semaphore;
|
|
1041
|
+
vk_command_buffer* cmd_buffer = nullptr;
|
|
1042
|
+
uint64_t cmd_buffer_use_counter = 0;
|
|
954
1043
|
};
|
|
955
1044
|
|
|
956
1045
|
struct vk_submission {
|
|
@@ -1091,6 +1180,13 @@ struct vk_op_push_constants {
|
|
|
1091
1180
|
float param4;
|
|
1092
1181
|
};
|
|
1093
1182
|
|
|
1183
|
+
struct vk_op_fwht_push_constants {
|
|
1184
|
+
uint32_t n_rows;
|
|
1185
|
+
uint32_t src_offset;
|
|
1186
|
+
uint32_t dst_offset;
|
|
1187
|
+
float scale;
|
|
1188
|
+
};
|
|
1189
|
+
|
|
1094
1190
|
struct vk_op_count_experts_push_constants {
|
|
1095
1191
|
uint32_t ne00;
|
|
1096
1192
|
uint32_t ne01;
|
|
@@ -1106,6 +1202,16 @@ struct vk_op_glu_push_constants {
|
|
|
1106
1202
|
uint32_t mode; // 0: default, 1: swapped, 2: split
|
|
1107
1203
|
float alpha; // for swiglu_oai
|
|
1108
1204
|
float limit;
|
|
1205
|
+
uint32_t nb01;
|
|
1206
|
+
uint32_t nb02;
|
|
1207
|
+
uint32_t nb03;
|
|
1208
|
+
uint32_t ne01;
|
|
1209
|
+
uint32_t ne02;
|
|
1210
|
+
uint32_t nb11;
|
|
1211
|
+
uint32_t nb12;
|
|
1212
|
+
uint32_t nb13;
|
|
1213
|
+
uint32_t ne11;
|
|
1214
|
+
uint32_t ne12;
|
|
1109
1215
|
};
|
|
1110
1216
|
|
|
1111
1217
|
struct vk_op_unary_push_constants {
|
|
@@ -1313,6 +1419,8 @@ struct vk_op_rope_push_constants {
|
|
|
1313
1419
|
uint32_t nb11;
|
|
1314
1420
|
uint32_t nb12;
|
|
1315
1421
|
uint32_t nb13;
|
|
1422
|
+
uint32_t a_offset;
|
|
1423
|
+
uint32_t d_offset;
|
|
1316
1424
|
};
|
|
1317
1425
|
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
|
|
1318
1426
|
|
|
@@ -1371,7 +1479,7 @@ struct vk_op_im2col_push_constants {
|
|
|
1371
1479
|
uint32_t IW; uint32_t IH;
|
|
1372
1480
|
uint32_t OW; uint32_t OH;
|
|
1373
1481
|
uint32_t KW; uint32_t KH;
|
|
1374
|
-
uint32_t
|
|
1482
|
+
uint32_t OH_batch;
|
|
1375
1483
|
uint32_t CHW;
|
|
1376
1484
|
int32_t s0; int32_t s1;
|
|
1377
1485
|
int32_t p0; int32_t p1;
|
|
@@ -1432,6 +1540,11 @@ struct vk_op_conv_transpose_1d_push_constants {
|
|
|
1432
1540
|
int32_t s0;
|
|
1433
1541
|
};
|
|
1434
1542
|
|
|
1543
|
+
struct vk_op_snake_push_constants {
|
|
1544
|
+
uint32_t ne0;
|
|
1545
|
+
uint32_t ne1;
|
|
1546
|
+
};
|
|
1547
|
+
|
|
1435
1548
|
struct vk_op_pool2d_push_constants {
|
|
1436
1549
|
uint32_t IW; uint32_t IH;
|
|
1437
1550
|
uint32_t OW; uint32_t OH;
|
|
@@ -1466,6 +1579,7 @@ struct vk_op_gated_delta_net_push_constants {
|
|
|
1466
1579
|
uint32_t sb1, sb2, sb3;
|
|
1467
1580
|
uint32_t neq1, rq3;
|
|
1468
1581
|
float scale;
|
|
1582
|
+
uint32_t K;
|
|
1469
1583
|
};
|
|
1470
1584
|
|
|
1471
1585
|
struct vk_op_ssm_scan_push_constants {
|
|
@@ -1641,7 +1755,7 @@ struct ggml_vk_garbage_collector {
|
|
|
1641
1755
|
};
|
|
1642
1756
|
|
|
1643
1757
|
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
|
|
1644
|
-
static void ggml_vk_load_shaders(vk_device& device);
|
|
1758
|
+
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
|
|
1645
1759
|
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
|
|
1646
1760
|
|
|
1647
1761
|
static bool vk_memory_logger_enabled = false;
|
|
@@ -1879,6 +1993,9 @@ struct ggml_backend_vk_context {
|
|
|
1879
1993
|
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
|
|
1880
1994
|
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
|
|
1881
1995
|
const ggml_tensor * prealloc_y_last_tensor_used {};
|
|
1996
|
+
// True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback.
|
|
1997
|
+
// If false, then it's contiguous.
|
|
1998
|
+
bool prealloc_y_last_decode_vector_staging {};
|
|
1882
1999
|
|
|
1883
2000
|
// Track which nodes have been used since the last sync, and whether they were written to
|
|
1884
2001
|
std::vector<const ggml_tensor *> unsynced_nodes_written;
|
|
@@ -1978,6 +2095,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|
|
1978
2095
|
GGML_UNUSED(src3);
|
|
1979
2096
|
}
|
|
1980
2097
|
|
|
2098
|
+
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
|
|
2099
|
+
p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
|
2100
|
+
p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
|
2101
|
+
|
|
2102
|
+
GGML_UNUSED(src1);
|
|
2103
|
+
GGML_UNUSED(src2);
|
|
2104
|
+
GGML_UNUSED(src3);
|
|
2105
|
+
}
|
|
2106
|
+
|
|
1981
2107
|
struct ggml_backend_vk_buffer_context {
|
|
1982
2108
|
vk_device_ref device;
|
|
1983
2109
|
vk_buffer dev_buffer;
|
|
@@ -2018,9 +2144,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
|
|
|
2018
2144
|
const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
2019
2145
|
std::string type = device ? "device" : "host";
|
|
2020
2146
|
auto it = allocations.find(buf->buffer);
|
|
2021
|
-
total_device -= device ? it->second : 0;
|
|
2022
|
-
total_host -= device ? 0 : it->second;
|
|
2023
2147
|
if (it != allocations.end()) {
|
|
2148
|
+
total_device -= device ? it->second : 0;
|
|
2149
|
+
total_host -= device ? 0 : it->second;
|
|
2024
2150
|
VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
|
|
2025
2151
|
allocations.erase(it);
|
|
2026
2152
|
} else {
|
|
@@ -2099,10 +2225,135 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
|
|
|
2099
2225
|
ctx->device->device.resetFences({ ctx->fence });
|
|
2100
2226
|
}
|
|
2101
2227
|
|
|
2102
|
-
|
|
2103
|
-
static uint32_t
|
|
2104
|
-
static
|
|
2105
|
-
|
|
2228
|
+
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
|
|
2229
|
+
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
|
|
2230
|
+
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
|
|
2231
|
+
|
|
2232
|
+
// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it
|
|
2233
|
+
// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the
|
|
2234
|
+
// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the
|
|
2235
|
+
// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction.
|
|
2236
|
+
// Returns true when the input used the extension (and `out` was populated with a
|
|
2237
|
+
// stripped copy); returns false otherwise without touching `out`.
|
|
2238
|
+
static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
|
|
2239
|
+
static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector";
|
|
2240
|
+
|
|
2241
|
+
if (word_count < 5) {
|
|
2242
|
+
return false;
|
|
2243
|
+
}
|
|
2244
|
+
|
|
2245
|
+
bool uses_decode_vector = false;
|
|
2246
|
+
for (size_t pos = 5; pos < word_count; ) {
|
|
2247
|
+
uint32_t word = code[pos];
|
|
2248
|
+
uint32_t wc = word >> spv::WordCountShift;
|
|
2249
|
+
uint32_t op = word & spv::OpCodeMask;
|
|
2250
|
+
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
|
2251
|
+
if (op == spv::OpExtension && wc >= 2) {
|
|
2252
|
+
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
|
2253
|
+
if (strcmp(s, kDecodeVectorExt) == 0) {
|
|
2254
|
+
uses_decode_vector = true;
|
|
2255
|
+
break;
|
|
2256
|
+
}
|
|
2257
|
+
}
|
|
2258
|
+
pos += wc;
|
|
2259
|
+
}
|
|
2260
|
+
|
|
2261
|
+
if (!uses_decode_vector) {
|
|
2262
|
+
return false;
|
|
2263
|
+
}
|
|
2264
|
+
|
|
2265
|
+
VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector");
|
|
2266
|
+
|
|
2267
|
+
// Bulk-copy unchanged runs and only break the run when an instruction needs to
|
|
2268
|
+
// be dropped or patched. Use reserve + insert/push_back so the destination buffer
|
|
2269
|
+
// is touched exactly once (no zero-initialization pass from resize()).
|
|
2270
|
+
out.clear();
|
|
2271
|
+
out.reserve(word_count);
|
|
2272
|
+
|
|
2273
|
+
size_t run_start = 0;
|
|
2274
|
+
auto flush_run = [&](size_t up_to) {
|
|
2275
|
+
if (up_to > run_start) {
|
|
2276
|
+
out.insert(out.end(), code + run_start, code + up_to);
|
|
2277
|
+
}
|
|
2278
|
+
};
|
|
2279
|
+
|
|
2280
|
+
for (size_t pos = 5; pos < word_count; ) {
|
|
2281
|
+
uint32_t word = code[pos];
|
|
2282
|
+
uint32_t wc = word >> spv::WordCountShift;
|
|
2283
|
+
uint32_t op = word & spv::OpCodeMask;
|
|
2284
|
+
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
|
2285
|
+
|
|
2286
|
+
if (op == spv::OpExtension && wc >= 2) {
|
|
2287
|
+
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
|
2288
|
+
if (strcmp(s, kDecodeVectorExt) == 0) {
|
|
2289
|
+
flush_run(pos);
|
|
2290
|
+
pos += wc;
|
|
2291
|
+
run_start = pos;
|
|
2292
|
+
continue;
|
|
2293
|
+
}
|
|
2294
|
+
}
|
|
2295
|
+
|
|
2296
|
+
if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) {
|
|
2297
|
+
flush_run(pos);
|
|
2298
|
+
pos += wc;
|
|
2299
|
+
run_start = pos;
|
|
2300
|
+
continue;
|
|
2301
|
+
}
|
|
2302
|
+
|
|
2303
|
+
if (op == kSpvOpCooperativeMatrixLoadTensorNV) {
|
|
2304
|
+
// [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...]
|
|
2305
|
+
GGML_ASSERT(wc >= 8);
|
|
2306
|
+
|
|
2307
|
+
uint32_t mem_mask = code[pos + 6];
|
|
2308
|
+
size_t cur = pos + 7;
|
|
2309
|
+
// Each of these MemoryAccess bits (when set) carries one trailing operand.
|
|
2310
|
+
cur += (mem_mask & 0x2) ? 1 : 0; // Aligned
|
|
2311
|
+
cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable
|
|
2312
|
+
cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible
|
|
2313
|
+
cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask
|
|
2314
|
+
cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask
|
|
2315
|
+
GGML_ASSERT(cur < pos + wc);
|
|
2316
|
+
|
|
2317
|
+
uint32_t ta_mask = code[cur];
|
|
2318
|
+
if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) {
|
|
2319
|
+
pos += wc;
|
|
2320
|
+
continue; // leave instruction inside the current unchanged run
|
|
2321
|
+
}
|
|
2322
|
+
|
|
2323
|
+
flush_run(pos);
|
|
2324
|
+
|
|
2325
|
+
// Append unchanged prefix of the instruction (header through the mem-extras).
|
|
2326
|
+
size_t inst_start = out.size();
|
|
2327
|
+
size_t pre_n = cur - pos;
|
|
2328
|
+
out.insert(out.end(), code + pos, code + pos + pre_n);
|
|
2329
|
+
|
|
2330
|
+
// Emit TA mask with the DecodeVectorFunc bit cleared.
|
|
2331
|
+
out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit);
|
|
2332
|
+
|
|
2333
|
+
// TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim;
|
|
2334
|
+
// DecodeVectorFunc (0x4) is dropped along with its trailing id operand.
|
|
2335
|
+
size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0);
|
|
2336
|
+
if (keep_ta_extras) {
|
|
2337
|
+
out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras);
|
|
2338
|
+
}
|
|
2339
|
+
|
|
2340
|
+
GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1);
|
|
2341
|
+
|
|
2342
|
+
// Patch the instruction header with the new (one-shorter) word count.
|
|
2343
|
+
uint32_t new_wc = wc - 1;
|
|
2344
|
+
out[inst_start] = (new_wc << spv::WordCountShift) | op;
|
|
2345
|
+
|
|
2346
|
+
pos += wc;
|
|
2347
|
+
run_start = pos;
|
|
2348
|
+
continue;
|
|
2349
|
+
}
|
|
2350
|
+
|
|
2351
|
+
pos += wc;
|
|
2352
|
+
}
|
|
2353
|
+
|
|
2354
|
+
flush_run(word_count);
|
|
2355
|
+
return true;
|
|
2356
|
+
}
|
|
2106
2357
|
|
|
2107
2358
|
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
|
2108
2359
|
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
|
@@ -2115,6 +2366,78 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
2115
2366
|
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
|
2116
2367
|
|
|
2117
2368
|
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
|
|
2369
|
+
|
|
2370
|
+
// Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for
|
|
2371
|
+
// separate shader variants compiled with -DRTE16.
|
|
2372
|
+
std::vector<uint32_t> spirv;
|
|
2373
|
+
if (device->float_controls_rte_fp16) {
|
|
2374
|
+
const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data);
|
|
2375
|
+
size_t word_count = spv_size / sizeof(uint32_t);
|
|
2376
|
+
spirv.assign(spv_words, spv_words + word_count);
|
|
2377
|
+
|
|
2378
|
+
// Find insertion points respecting SPIR-V layout order:
|
|
2379
|
+
// Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ...
|
|
2380
|
+
size_t pos = 5; // skip header
|
|
2381
|
+
size_t cap_insert_pos = pos;
|
|
2382
|
+
size_t ext_insert_pos = pos;
|
|
2383
|
+
size_t exec_insert_pos = pos;
|
|
2384
|
+
uint32_t entry_point_id = 0;
|
|
2385
|
+
|
|
2386
|
+
while (pos < spirv.size()) {
|
|
2387
|
+
uint32_t opcode = spirv[pos] & spv::OpCodeMask;
|
|
2388
|
+
uint32_t len = spirv[pos] >> spv::WordCountShift;
|
|
2389
|
+
if (len == 0) break;
|
|
2390
|
+
|
|
2391
|
+
if (opcode == spv::OpCapability) {
|
|
2392
|
+
cap_insert_pos = pos + len;
|
|
2393
|
+
ext_insert_pos = pos + len;
|
|
2394
|
+
} else if (opcode == spv::OpExtension) {
|
|
2395
|
+
ext_insert_pos = pos + len;
|
|
2396
|
+
} else if (opcode == spv::OpEntryPoint) {
|
|
2397
|
+
entry_point_id = spirv[pos + 2];
|
|
2398
|
+
exec_insert_pos = pos + len;
|
|
2399
|
+
} else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) {
|
|
2400
|
+
exec_insert_pos = pos + len;
|
|
2401
|
+
} else if (entry_point_id != 0) {
|
|
2402
|
+
break;
|
|
2403
|
+
}
|
|
2404
|
+
|
|
2405
|
+
pos += len;
|
|
2406
|
+
}
|
|
2407
|
+
|
|
2408
|
+
// Insert from latest position first so earlier indices stay valid.
|
|
2409
|
+
|
|
2410
|
+
// OpExecutionMode %entrypoint RoundingModeRTE 16
|
|
2411
|
+
uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 };
|
|
2412
|
+
spirv.insert(spirv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode));
|
|
2413
|
+
|
|
2414
|
+
// OpExtension "SPV_KHR_float_controls"
|
|
2415
|
+
const char ext_str[] = "SPV_KHR_float_controls";
|
|
2416
|
+
size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t));
|
|
2417
|
+
std::vector<uint32_t> extension(1 + ext_str_words, 0);
|
|
2418
|
+
extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension;
|
|
2419
|
+
memcpy(&extension[1], ext_str, sizeof(ext_str));
|
|
2420
|
+
spirv.insert(spirv.begin() + ext_insert_pos, extension.begin(), extension.end());
|
|
2421
|
+
|
|
2422
|
+
// OpCapability RoundingModeRTE
|
|
2423
|
+
uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE };
|
|
2424
|
+
spirv.insert(spirv.begin() + cap_insert_pos, std::begin(capability), std::end(capability));
|
|
2425
|
+
|
|
2426
|
+
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
|
2427
|
+
}
|
|
2428
|
+
|
|
2429
|
+
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
|
2430
|
+
if (device->coopmat2 && !device->coopmat2_decode_vector) {
|
|
2431
|
+
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
|
|
2432
|
+
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
|
|
2433
|
+
std::vector<uint32_t> stripped;
|
|
2434
|
+
if (ggml_vk_strip_decode_vector(src, src_n, stripped)) {
|
|
2435
|
+
spirv = std::move(stripped);
|
|
2436
|
+
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
|
2437
|
+
}
|
|
2438
|
+
}
|
|
2439
|
+
#endif
|
|
2440
|
+
|
|
2118
2441
|
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
|
2119
2442
|
|
|
2120
2443
|
vk::PushConstantRange pcr(
|
|
@@ -2196,7 +2519,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
2196
2519
|
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
|
|
2197
2520
|
throw e;
|
|
2198
2521
|
}
|
|
2199
|
-
pipeline->compiled = true;
|
|
2200
2522
|
|
|
2201
2523
|
if (vk_instance.debug_utils_support) {
|
|
2202
2524
|
vk::DebugUtilsObjectNameInfoEXT duoni;
|
|
@@ -2245,14 +2567,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
2245
2567
|
}
|
|
2246
2568
|
}
|
|
2247
2569
|
|
|
2248
|
-
device->all_pipelines.push_back(pipeline);
|
|
2249
|
-
|
|
2250
2570
|
{
|
|
2251
|
-
std::lock_guard<std::mutex> guard(
|
|
2252
|
-
|
|
2253
|
-
|
|
2571
|
+
std::lock_guard<std::mutex> guard(device->compile_mutex);
|
|
2572
|
+
device->all_pipelines.push_back(pipeline);
|
|
2573
|
+
pipeline->compiled = true;
|
|
2574
|
+
pipeline->compile_pending = false;
|
|
2254
2575
|
}
|
|
2255
|
-
|
|
2576
|
+
device->compile_cv.notify_all();
|
|
2256
2577
|
}
|
|
2257
2578
|
|
|
2258
2579
|
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
|
|
@@ -2268,8 +2589,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
|
|
|
2268
2589
|
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
|
|
2269
2590
|
ctx->pipeline_descriptor_set_requirements += n;
|
|
2270
2591
|
if (!pipeline->compiled) {
|
|
2271
|
-
|
|
2272
|
-
ggml_vk_load_shaders(ctx->device);
|
|
2592
|
+
ggml_vk_load_shaders(ctx->device, pipeline);
|
|
2273
2593
|
}
|
|
2274
2594
|
ggml_pipeline_allocate_descriptor_sets(ctx);
|
|
2275
2595
|
}
|
|
@@ -2319,7 +2639,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman
|
|
|
2319
2639
|
vk::CommandBufferLevel::ePrimary,
|
|
2320
2640
|
1);
|
|
2321
2641
|
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
|
|
2322
|
-
p.cmd_buffers.push_back({ cmd_buffers.front(), true });
|
|
2642
|
+
p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true });
|
|
2323
2643
|
return &p.cmd_buffers[p.cmd_buffers.size()-1];
|
|
2324
2644
|
}
|
|
2325
2645
|
|
|
@@ -2788,6 +3108,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
|
|
|
2788
3108
|
);
|
|
2789
3109
|
}
|
|
2790
3110
|
|
|
3111
|
+
static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) {
|
|
3112
|
+
VK_LOG_DEBUG("ggml_vk_set_event()");
|
|
3113
|
+
|
|
3114
|
+
ctx->s->buffer->buf.resetEvent(
|
|
3115
|
+
event,
|
|
3116
|
+
ctx->p->q->stage_flags
|
|
3117
|
+
);
|
|
3118
|
+
}
|
|
3119
|
+
|
|
2791
3120
|
static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
|
|
2792
3121
|
VK_LOG_DEBUG("ggml_vk_set_event()");
|
|
2793
3122
|
|
|
@@ -2833,11 +3162,10 @@ struct vk_fa_tuning_params {
|
|
|
2833
3162
|
}
|
|
2834
3163
|
};
|
|
2835
3164
|
|
|
2836
|
-
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
|
2837
|
-
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
|
3165
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
|
|
3166
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16);
|
|
2838
3167
|
|
|
2839
|
-
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type
|
|
2840
|
-
GGML_UNUSED(kv_type);
|
|
3168
|
+
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
|
2841
3169
|
|
|
2842
3170
|
vk_fa_tuning_params result{};
|
|
2843
3171
|
result.path = FA_SCALAR;
|
|
@@ -2889,7 +3217,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
|
|
2889
3217
|
|
|
2890
3218
|
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
|
2891
3219
|
|
|
2892
|
-
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
|
3220
|
+
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
|
|
2893
3221
|
result.block_rows /= 2;
|
|
2894
3222
|
}
|
|
2895
3223
|
|
|
@@ -2912,10 +3240,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
|
|
|
2912
3240
|
return result;
|
|
2913
3241
|
}
|
|
2914
3242
|
|
|
2915
|
-
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type
|
|
3243
|
+
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
|
2916
3244
|
GGML_UNUSED(n_rows);
|
|
2917
3245
|
GGML_UNUSED(n_kv);
|
|
2918
|
-
GGML_UNUSED(
|
|
3246
|
+
GGML_UNUSED(k_type);
|
|
3247
|
+
GGML_UNUSED(v_type);
|
|
2919
3248
|
GGML_UNUSED(f32acc);
|
|
2920
3249
|
|
|
2921
3250
|
vk_fa_tuning_params result{};
|
|
@@ -2942,7 +3271,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
|
|
|
2942
3271
|
return result;
|
|
2943
3272
|
}
|
|
2944
3273
|
|
|
2945
|
-
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type
|
|
3274
|
+
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
|
2946
3275
|
GGML_UNUSED(n_kv);
|
|
2947
3276
|
GGML_UNUSED(f32acc);
|
|
2948
3277
|
|
|
@@ -2956,7 +3285,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
|
|
2956
3285
|
if (small_rows) {
|
|
2957
3286
|
result.block_rows = 32;
|
|
2958
3287
|
result.block_cols = 32;
|
|
2959
|
-
} else if (ggml_is_quantized(
|
|
3288
|
+
} else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
|
|
2960
3289
|
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
|
|
2961
3290
|
result.block_cols = 32;
|
|
2962
3291
|
} else {
|
|
@@ -2970,10 +3299,17 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
|
|
|
2970
3299
|
return result;
|
|
2971
3300
|
}
|
|
2972
3301
|
|
|
2973
|
-
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type
|
|
3302
|
+
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
|
|
2974
3303
|
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
|
2975
3304
|
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
|
2976
3305
|
|
|
3306
|
+
if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) {
|
|
3307
|
+
path = FA_COOPMAT1;
|
|
3308
|
+
}
|
|
3309
|
+
if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) {
|
|
3310
|
+
path = FA_SCALAR;
|
|
3311
|
+
}
|
|
3312
|
+
|
|
2977
3313
|
if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
|
|
2978
3314
|
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
|
|
2979
3315
|
path = FA_SCALAR;
|
|
@@ -2982,8 +3318,8 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|
|
2982
3318
|
if (path == FA_COOPMAT1) {
|
|
2983
3319
|
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
|
2984
3320
|
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
|
2985
|
-
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv,
|
|
2986
|
-
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
|
3321
|
+
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
|
3322
|
+
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type);
|
|
2987
3323
|
|
|
2988
3324
|
if (!shape_ok || !shmem_ok) {
|
|
2989
3325
|
path = FA_SCALAR;
|
|
@@ -2995,20 +3331,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
|
|
2995
3331
|
path = FA_SCALAR;
|
|
2996
3332
|
}
|
|
2997
3333
|
|
|
3334
|
+
// Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
|
|
3335
|
+
if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
|
|
3336
|
+
path = FA_COOPMAT2;
|
|
3337
|
+
}
|
|
3338
|
+
|
|
2998
3339
|
switch (path) {
|
|
2999
3340
|
case FA_SCALAR:
|
|
3000
|
-
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv,
|
|
3341
|
+
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
|
3001
3342
|
case FA_COOPMAT1:
|
|
3002
|
-
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv,
|
|
3343
|
+
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
|
3003
3344
|
case FA_COOPMAT2:
|
|
3004
|
-
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv,
|
|
3345
|
+
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
|
|
3005
3346
|
default:
|
|
3006
3347
|
throw std::runtime_error("unsupported FaCodePath");
|
|
3007
3348
|
}
|
|
3008
3349
|
}
|
|
3009
3350
|
|
|
3010
3351
|
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
|
3011
|
-
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
|
3352
|
+
bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
|
|
3012
3353
|
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
|
|
3013
3354
|
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
|
|
3014
3355
|
|
|
@@ -3019,12 +3360,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
|
|
|
3019
3360
|
|
|
3020
3361
|
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
|
3021
3362
|
|
|
3022
|
-
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
|
|
3363
|
+
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
|
|
3023
3364
|
}
|
|
3024
3365
|
|
|
3025
3366
|
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
|
3026
|
-
|
|
3027
|
-
|
|
3367
|
+
const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
|
|
3368
|
+
if (t == GGML_TYPE_F32) return 16u;
|
|
3369
|
+
return (uint32_t) ggml_type_size(t);
|
|
3370
|
+
};
|
|
3371
|
+
return {
|
|
3372
|
+
/* 0 WorkGroupSize */ state.workgroup_size,
|
|
3373
|
+
/* 1 Br */ state.Br,
|
|
3374
|
+
/* 2 Bc */ state.Bc,
|
|
3375
|
+
/* 3 HSK */ state.HSK,
|
|
3376
|
+
/* 4 HSV */ state.HSV,
|
|
3377
|
+
/* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
|
|
3378
|
+
/* 6 D_split */ state.D_split,
|
|
3379
|
+
/* 7 row_split */ state.row_split,
|
|
3380
|
+
/* 8 SubGroupSize */ state.subgroup_size,
|
|
3381
|
+
/* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
|
|
3382
|
+
/*10 Flags */ state.flags,
|
|
3383
|
+
/*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
|
|
3384
|
+
/*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
|
|
3385
|
+
/*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
|
|
3386
|
+
/*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
|
|
3387
|
+
/*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
|
|
3388
|
+
};
|
|
3028
3389
|
}
|
|
3029
3390
|
|
|
3030
3391
|
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
|
@@ -3033,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
3033
3394
|
switch (src0_type) {
|
|
3034
3395
|
case GGML_TYPE_IQ1_S:
|
|
3035
3396
|
case GGML_TYPE_IQ1_M:
|
|
3036
|
-
|
|
3397
|
+
// Regular matmul uses the compact uint16_t IQ1 grid; the expanded
|
|
3398
|
+
// uint32_t grid is only enabled for the q8_1/int-dot vector path.
|
|
3399
|
+
lut_size = 2*2048;
|
|
3037
3400
|
break;
|
|
3038
3401
|
case GGML_TYPE_IQ2_XXS:
|
|
3039
3402
|
lut_size = 8*256;
|
|
@@ -3055,6 +3418,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
3055
3418
|
case GGML_TYPE_MXFP4:
|
|
3056
3419
|
lut_size = 4*16;
|
|
3057
3420
|
break;
|
|
3421
|
+
case GGML_TYPE_NVFP4:
|
|
3422
|
+
// Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4).
|
|
3423
|
+
lut_size = 4*16 + 128u * (uint32_t)sizeof(float);
|
|
3424
|
+
break;
|
|
3058
3425
|
default:
|
|
3059
3426
|
break;
|
|
3060
3427
|
}
|
|
@@ -3078,6 +3445,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
3078
3445
|
return supported;
|
|
3079
3446
|
}
|
|
3080
3447
|
|
|
3448
|
+
// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses
|
|
3449
|
+
// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather
|
|
3450
|
+
// than the float load buffers checked by ggml_vk_matmul_shmem_support.
|
|
3451
|
+
// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline.
|
|
3452
|
+
static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
|
3453
|
+
|
|
3454
|
+
// FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float.
|
|
3455
|
+
const uint32_t fp_size = device->fp16 ? 2u : 4u;
|
|
3456
|
+
const uint32_t fp_align = fp_size;
|
|
3457
|
+
const uint32_t fp2_size = 2u * fp_size;
|
|
3458
|
+
const uint32_t fp2_align = device->fp16 ? 4u : 8u;
|
|
3459
|
+
|
|
3460
|
+
struct member { uint32_t size, align; };
|
|
3461
|
+
auto std430_size = [](std::initializer_list<member> members) {
|
|
3462
|
+
uint32_t off = 0, struct_align = 1;
|
|
3463
|
+
for (const auto &m : members) {
|
|
3464
|
+
off = (off + m.align - 1) & ~(m.align - 1);
|
|
3465
|
+
off += m.size;
|
|
3466
|
+
struct_align = std::max(struct_align, m.align);
|
|
3467
|
+
}
|
|
3468
|
+
return (off + struct_align - 1) & ~(struct_align - 1);
|
|
3469
|
+
};
|
|
3470
|
+
|
|
3471
|
+
uint32_t block_a_size = 0;
|
|
3472
|
+
switch (src0_type) {
|
|
3473
|
+
case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm
|
|
3474
|
+
case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2)
|
|
3475
|
+
case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm
|
|
3476
|
+
case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2)
|
|
3477
|
+
case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm
|
|
3478
|
+
case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d
|
|
3479
|
+
case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2)
|
|
3480
|
+
case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2)
|
|
3481
|
+
case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2)
|
|
3482
|
+
case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2)
|
|
3483
|
+
case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2)
|
|
3484
|
+
default:
|
|
3485
|
+
return false;
|
|
3486
|
+
}
|
|
3487
|
+
|
|
3488
|
+
// block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; }
|
|
3489
|
+
const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}});
|
|
3490
|
+
|
|
3491
|
+
const uint32_t BM = warptile[1];
|
|
3492
|
+
const uint32_t BN = warptile[2];
|
|
3493
|
+
// mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise.
|
|
3494
|
+
const uint32_t BK_STEP = mul_mat_id ? 1u : 4u;
|
|
3495
|
+
|
|
3496
|
+
const uint32_t buf_a_size = BM * BK_STEP * block_a_size;
|
|
3497
|
+
const uint32_t buf_b_size = BN * BK_STEP * block_b_size;
|
|
3498
|
+
const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u;
|
|
3499
|
+
|
|
3500
|
+
const uint32_t warps = warptile[0] / warptile[10];
|
|
3501
|
+
const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u;
|
|
3502
|
+
|
|
3503
|
+
const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh;
|
|
3504
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
3505
|
+
|
|
3506
|
+
VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
|
|
3507
|
+
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported);
|
|
3508
|
+
|
|
3509
|
+
return supported;
|
|
3510
|
+
}
|
|
3511
|
+
|
|
3081
3512
|
struct GpuPipelineConfig {
|
|
3082
3513
|
// GPU architecture identifier.
|
|
3083
3514
|
// Example: vk_device_architecture::AMD_GCN
|
|
@@ -3145,10 +3576,40 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
|
|
|
3145
3576
|
return 0; // If no matching configuration is found
|
|
3146
3577
|
}
|
|
3147
3578
|
|
|
3148
|
-
|
|
3579
|
+
// Whether scalar flash attention will use the MMQ path for the given k_type.
|
|
3580
|
+
static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
|
|
3581
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3582
|
+
return device->integer_dot_product && device->subgroup_clustered &&
|
|
3583
|
+
(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
|
|
3584
|
+
k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
|
|
3585
|
+
k_type == GGML_TYPE_Q8_0);
|
|
3586
|
+
#else
|
|
3587
|
+
GGML_UNUSED(device);
|
|
3588
|
+
GGML_UNUSED(k_type);
|
|
3589
|
+
return false;
|
|
3590
|
+
#endif
|
|
3591
|
+
}
|
|
3592
|
+
|
|
3593
|
+
// load_shaders walks the pipeline list under compile_mutex and either claims
|
|
3594
|
+
// the requested pipeline for compilation or, if another thread is already
|
|
3595
|
+
// compiling it, drops the lock and waits on compile_cv. Compiles themselves
|
|
3596
|
+
// run unlocked.
|
|
3597
|
+
struct CompileTask {
|
|
3598
|
+
vk_pipeline pipeline;
|
|
3599
|
+
size_t spv_size;
|
|
3600
|
+
const void * spv_data;
|
|
3601
|
+
std::string entrypoint;
|
|
3602
|
+
uint32_t parameter_count;
|
|
3603
|
+
std::array<uint32_t, 3> wg_denoms;
|
|
3604
|
+
std::vector<uint32_t> specialization_constants;
|
|
3605
|
+
bool disable_robustness;
|
|
3606
|
+
bool require_full_subgroups;
|
|
3607
|
+
uint32_t required_subgroup_size;
|
|
3608
|
+
};
|
|
3609
|
+
|
|
3610
|
+
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
|
3149
3611
|
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
|
3150
3612
|
|
|
3151
|
-
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
3152
3613
|
// some shaders have a minimum subgroup size
|
|
3153
3614
|
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
|
|
3154
3615
|
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
|
|
@@ -3178,6 +3639,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3178
3639
|
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
|
|
3179
3640
|
|
|
3180
3641
|
uint32_t l_align, m_align, s_align;
|
|
3642
|
+
|
|
3643
|
+
vk_pipeline wait_pipeline;
|
|
3644
|
+
CompileTask claimed_task {};
|
|
3645
|
+
bool has_claimed_task = false;
|
|
3646
|
+
|
|
3647
|
+
// The rest of the walk reads and writes shared device state, so hold the
|
|
3648
|
+
// lock until we're done deciding what to compile.
|
|
3649
|
+
std::unique_lock<std::mutex> compile_lock(device->compile_mutex);
|
|
3650
|
+
|
|
3181
3651
|
if (device->coopmat2) {
|
|
3182
3652
|
// spec constants and tile sizes for non-quant matmul/matmul_id
|
|
3183
3653
|
l_warptile = { 256, 128, 256, 64, 1 };
|
|
@@ -3204,9 +3674,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3204
3674
|
s_mmq_wg_denoms_k = { 32, 64, 1 };
|
|
3205
3675
|
|
|
3206
3676
|
// spec constants and tile sizes for quant matmul_id
|
|
3207
|
-
|
|
3208
|
-
|
|
3209
|
-
|
|
3677
|
+
const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u;
|
|
3678
|
+
l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size };
|
|
3679
|
+
m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
|
|
3680
|
+
s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
|
|
3210
3681
|
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
3211
3682
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
3212
3683
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
@@ -3310,6 +3781,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3310
3781
|
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
|
|
3311
3782
|
device->mul_mat_id_l[i] = false;
|
|
3312
3783
|
}
|
|
3784
|
+
|
|
3785
|
+
// The q8_1 mmq path has its own (larger) shmem layout, check it separately.
|
|
3786
|
+
// K-quants use the _int_k warptiles, others use _int.
|
|
3787
|
+
const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K ||
|
|
3788
|
+
t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K ||
|
|
3789
|
+
t == GGML_TYPE_Q6_K);
|
|
3790
|
+
const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int;
|
|
3791
|
+
const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int;
|
|
3792
|
+
const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int;
|
|
3793
|
+
const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int;
|
|
3794
|
+
const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int;
|
|
3795
|
+
const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int;
|
|
3796
|
+
|
|
3797
|
+
if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) {
|
|
3798
|
+
device->mul_mat_s_int[i] = false;
|
|
3799
|
+
device->mul_mat_m_int[i] = false;
|
|
3800
|
+
device->mul_mat_l_int[i] = false;
|
|
3801
|
+
} else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) {
|
|
3802
|
+
device->mul_mat_m_int[i] = false;
|
|
3803
|
+
device->mul_mat_l_int[i] = false;
|
|
3804
|
+
} else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) {
|
|
3805
|
+
device->mul_mat_l_int[i] = false;
|
|
3806
|
+
}
|
|
3807
|
+
|
|
3808
|
+
if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) {
|
|
3809
|
+
device->mul_mat_id_s_int[i] = false;
|
|
3810
|
+
device->mul_mat_id_m_int[i] = false;
|
|
3811
|
+
device->mul_mat_id_l_int[i] = false;
|
|
3812
|
+
} else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) {
|
|
3813
|
+
device->mul_mat_id_m_int[i] = false;
|
|
3814
|
+
device->mul_mat_id_l_int[i] = false;
|
|
3815
|
+
} else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) {
|
|
3816
|
+
device->mul_mat_id_l_int[i] = false;
|
|
3817
|
+
}
|
|
3313
3818
|
}
|
|
3314
3819
|
}
|
|
3315
3820
|
|
|
@@ -3329,7 +3834,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3329
3834
|
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
3330
3835
|
}
|
|
3331
3836
|
|
|
3332
|
-
std::vector<std::future<void>> compiles;
|
|
3333
3837
|
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
|
|
3334
3838
|
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
|
|
3335
3839
|
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
|
|
@@ -3363,23 +3867,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3363
3867
|
#endif
|
|
3364
3868
|
}
|
|
3365
3869
|
|
|
3366
|
-
|
|
3870
|
+
// We only care about the pipeline this call asked for; the rest
|
|
3871
|
+
// (including the 64-bit indexing variant) are handled by their
|
|
3872
|
+
// own request_descriptor_sets / load_shaders calls.
|
|
3873
|
+
if (pipeline.get() != requested.get()) {
|
|
3367
3874
|
continue;
|
|
3368
3875
|
}
|
|
3369
|
-
|
|
3370
|
-
|
|
3371
|
-
|
|
3372
|
-
// wait until fewer than N compiles are in progress
|
|
3373
|
-
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
|
|
3374
|
-
std::unique_lock<std::mutex> guard(compile_count_mutex);
|
|
3375
|
-
while (compile_count >= N) {
|
|
3376
|
-
compile_count_cond.wait(guard);
|
|
3377
|
-
}
|
|
3378
|
-
compile_count++;
|
|
3876
|
+
|
|
3877
|
+
if (pipeline->compiled) {
|
|
3878
|
+
continue;
|
|
3379
3879
|
}
|
|
3380
3880
|
|
|
3381
|
-
|
|
3382
|
-
|
|
3881
|
+
wait_pipeline = pipeline;
|
|
3882
|
+
|
|
3883
|
+
if (!pipeline->compile_pending) {
|
|
3884
|
+
pipeline->compile_pending = true;
|
|
3885
|
+
claimed_task.pipeline = pipeline;
|
|
3886
|
+
claimed_task.spv_size = spv_size;
|
|
3887
|
+
claimed_task.spv_data = spv_data;
|
|
3888
|
+
claimed_task.entrypoint = entrypoint;
|
|
3889
|
+
claimed_task.parameter_count = parameter_count;
|
|
3890
|
+
claimed_task.wg_denoms = wg_denoms;
|
|
3891
|
+
claimed_task.specialization_constants = specialization_constants;
|
|
3892
|
+
claimed_task.disable_robustness = disable_robustness;
|
|
3893
|
+
claimed_task.require_full_subgroups = require_full_subgroups;
|
|
3894
|
+
claimed_task.required_subgroup_size = required_subgroup_size;
|
|
3895
|
+
has_claimed_task = true;
|
|
3896
|
+
}
|
|
3383
3897
|
}
|
|
3384
3898
|
};
|
|
3385
3899
|
|
|
@@ -3391,64 +3905,132 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3391
3905
|
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
|
3392
3906
|
};
|
|
3393
3907
|
|
|
3394
|
-
|
|
3395
|
-
|
|
3396
|
-
|
|
3397
|
-
|
|
3398
|
-
|
|
3399
|
-
|
|
3400
|
-
|
|
3401
|
-
|
|
3402
|
-
|
|
3403
|
-
|
|
3404
|
-
|
|
3405
|
-
|
|
3406
|
-
|
|
3407
|
-
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
|
|
3411
|
-
|
|
3412
|
-
|
|
3413
|
-
|
|
3414
|
-
|
|
3415
|
-
|
|
3416
|
-
|
|
3417
|
-
|
|
3908
|
+
// FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
|
|
3909
|
+
// quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
|
|
3910
|
+
|
|
3911
|
+
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
|
3912
|
+
if (fa.first.path != FA_SCALAR) continue;
|
|
3913
|
+
const uint32_t Br = fa.first.Br;
|
|
3914
|
+
const uint32_t Bc = fa.first.Bc;
|
|
3915
|
+
const bool aligned = fa.first.aligned;
|
|
3916
|
+
const bool f32acc = fa.first.f32acc;
|
|
3917
|
+
const uint32_t fa_sgs = fa.first.subgroup_size;
|
|
3918
|
+
const bool fa_ds = fa.first.subgroup_size == 0;
|
|
3919
|
+
|
|
3920
|
+
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
|
3921
|
+
const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
|
|
3922
|
+
const void * spv_data = nullptr;
|
|
3923
|
+
size_t spv_size = 0;
|
|
3924
|
+
const char *name = nullptr;
|
|
3925
|
+
if (bf16_kv) {
|
|
3926
|
+
spv_data = flash_attn_f32_f16_fp32_data;
|
|
3927
|
+
spv_size = flash_attn_f32_f16_fp32_len;
|
|
3928
|
+
name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16";
|
|
3929
|
+
} else if (use_mmq) {
|
|
3930
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3931
|
+
if (device->fp16) {
|
|
3932
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
|
|
3933
|
+
else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
|
|
3934
|
+
} else {
|
|
3935
|
+
spv_data = flash_attn_f32_f16_fp32_int8_data;
|
|
3936
|
+
spv_size = flash_attn_f32_f16_fp32_int8_len;
|
|
3937
|
+
}
|
|
3938
|
+
#endif
|
|
3939
|
+
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
|
3940
|
+
} else {
|
|
3941
|
+
if (device->fp16) {
|
|
3942
|
+
if (device->dot2_f16) {
|
|
3943
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; }
|
|
3944
|
+
else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; }
|
|
3945
|
+
} else {
|
|
3946
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
|
|
3947
|
+
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
|
|
3948
|
+
}
|
|
3949
|
+
} else {
|
|
3950
|
+
spv_data = flash_attn_f32_f16_fp32_data;
|
|
3951
|
+
spv_size = flash_attn_f32_f16_fp32_len;
|
|
3952
|
+
}
|
|
3953
|
+
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
|
|
3418
3954
|
}
|
|
3419
|
-
|
|
3420
|
-
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
3424
|
-
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
|
3425
|
-
} else {
|
|
3426
|
-
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
|
3427
|
-
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
|
3428
|
-
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
|
3429
|
-
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
|
3955
|
+
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
|
3956
|
+
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
|
3957
|
+
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
|
3958
|
+
!fa_ds, !fa_ds ? fa_sgs : 0);
|
|
3430
3959
|
}
|
|
3960
|
+
|
|
3431
3961
|
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
3432
3962
|
if (device->coopmat1_fa_support) {
|
|
3433
|
-
|
|
3434
|
-
|
|
3435
|
-
|
|
3436
|
-
|
|
3963
|
+
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
|
3964
|
+
if (fa.first.path != FA_COOPMAT1) continue;
|
|
3965
|
+
const uint32_t Br = fa.first.Br;
|
|
3966
|
+
const uint32_t Bc = fa.first.Bc;
|
|
3967
|
+
const bool aligned = fa.first.aligned;
|
|
3968
|
+
const bool f32acc = fa.first.f32acc;
|
|
3969
|
+
const uint32_t fa_sgs = fa.first.subgroup_size;
|
|
3970
|
+
const bool fa_ds = fa.first.subgroup_size == 0;
|
|
3971
|
+
|
|
3972
|
+
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
|
3973
|
+
|
|
3974
|
+
const void * spv_data;
|
|
3975
|
+
size_t spv_size;
|
|
3976
|
+
const char *name;
|
|
3977
|
+
if (bf16_kv) {
|
|
3978
|
+
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
3979
|
+
if (!device->coopmat_bf16_support) continue;
|
|
3980
|
+
spv_data = flash_attn_f32_f16_bf16_cm1_data;
|
|
3981
|
+
spv_size = flash_attn_f32_f16_bf16_cm1_len;
|
|
3982
|
+
name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1";
|
|
3983
|
+
#else
|
|
3984
|
+
continue;
|
|
3985
|
+
#endif
|
|
3986
|
+
} else {
|
|
3987
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
|
|
3988
|
+
else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
|
|
3989
|
+
name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
|
|
3990
|
+
}
|
|
3991
|
+
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
|
3992
|
+
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
|
3993
|
+
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
|
|
3994
|
+
!fa_ds, !fa_ds ? fa_sgs : 0);
|
|
3995
|
+
}
|
|
3437
3996
|
}
|
|
3438
3997
|
#endif
|
|
3998
|
+
|
|
3439
3999
|
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
3440
4000
|
if (device->coopmat2) {
|
|
3441
|
-
|
|
3442
|
-
|
|
3443
|
-
|
|
3444
|
-
|
|
3445
|
-
|
|
3446
|
-
|
|
3447
|
-
|
|
3448
|
-
|
|
4001
|
+
for (auto &fa : device->pipeline_flash_attn_f32_f16) {
|
|
4002
|
+
if (fa.first.path != FA_COOPMAT2) continue;
|
|
4003
|
+
const uint32_t Br = fa.first.Br;
|
|
4004
|
+
const uint32_t Bc = fa.first.Bc;
|
|
4005
|
+
const bool aligned = fa.first.aligned;
|
|
4006
|
+
const bool f32acc = fa.first.f32acc;
|
|
4007
|
+
|
|
4008
|
+
const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
|
|
4009
|
+
const void * spv_data;
|
|
4010
|
+
size_t spv_size;
|
|
4011
|
+
const char * name;
|
|
4012
|
+
if (bf16_kv) {
|
|
4013
|
+
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
4014
|
+
if (!device->coopmat2_bf16_support) continue;
|
|
4015
|
+
spv_data = flash_attn_f32_f16_bf16_cm2_data;
|
|
4016
|
+
spv_size = flash_attn_f32_f16_bf16_cm2_len;
|
|
4017
|
+
name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2";
|
|
4018
|
+
#else
|
|
4019
|
+
continue;
|
|
4020
|
+
#endif
|
|
4021
|
+
} else if (aligned) {
|
|
4022
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
|
|
4023
|
+
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
|
|
4024
|
+
} else {
|
|
4025
|
+
if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
|
|
4026
|
+
else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
|
|
4027
|
+
}
|
|
4028
|
+
ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
|
|
4029
|
+
sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
|
|
4030
|
+
get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
|
|
4031
|
+
}
|
|
3449
4032
|
}
|
|
3450
4033
|
#endif
|
|
3451
|
-
#undef CREATE_FA
|
|
3452
4034
|
|
|
3453
4035
|
const int mul_mat_id_param_count = 5;
|
|
3454
4036
|
|
|
@@ -3475,6 +4057,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3475
4057
|
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
3476
4058
|
}
|
|
3477
4059
|
#endif
|
|
4060
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3478
4061
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3479
4062
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3480
4063
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
@@ -3495,6 +4078,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3495
4078
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3496
4079
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3497
4080
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
4081
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
3498
4082
|
|
|
3499
4083
|
GGML_ASSERT(device->subgroup_ballot);
|
|
3500
4084
|
|
|
@@ -3504,6 +4088,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3504
4088
|
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
|
|
3505
4089
|
}
|
|
3506
4090
|
#endif
|
|
4091
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3507
4092
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3508
4093
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3509
4094
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
@@ -3524,6 +4109,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3524
4109
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3525
4110
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3526
4111
|
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
4112
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
|
|
3527
4113
|
#undef CREATE_MM
|
|
3528
4114
|
#undef CREATE_MM2
|
|
3529
4115
|
} else
|
|
@@ -3565,6 +4151,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3565
4151
|
#endif
|
|
3566
4152
|
|
|
3567
4153
|
if (device->coopmat_acc_f16_support) {
|
|
4154
|
+
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3568
4155
|
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3569
4156
|
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3570
4157
|
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
@@ -3586,7 +4173,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3586
4173
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3587
4174
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3588
4175
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
4176
|
+
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3589
4177
|
} else {
|
|
4178
|
+
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3590
4179
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3591
4180
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3592
4181
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
@@ -3608,6 +4197,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3608
4197
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3609
4198
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3610
4199
|
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
4200
|
+
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
3611
4201
|
}
|
|
3612
4202
|
|
|
3613
4203
|
GGML_ASSERT(device->subgroup_ballot);
|
|
@@ -3621,6 +4211,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3621
4211
|
}
|
|
3622
4212
|
#endif
|
|
3623
4213
|
|
|
4214
|
+
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3624
4215
|
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3625
4216
|
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3626
4217
|
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
@@ -3641,13 +4232,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3641
4232
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3642
4233
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3643
4234
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
4235
|
+
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
|
|
3644
4236
|
#undef CREATE_MM2
|
|
3645
4237
|
#undef CREATE_MM
|
|
3646
4238
|
} else
|
|
3647
4239
|
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
3648
4240
|
if (device->fp16) {
|
|
3649
4241
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
4242
|
+
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
|
3650
4243
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
|
4244
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
4245
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4246
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
4247
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4248
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
4249
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4250
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
4251
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4252
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
4253
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4254
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
4255
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
4256
|
+
|
|
4257
|
+
// bf16 scalar path promotes to f32, no dot2 variant
|
|
4258
|
+
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
|
3651
4259
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
3652
4260
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3653
4261
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
@@ -3662,13 +4270,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3662
4270
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3663
4271
|
|
|
3664
4272
|
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
|
3665
|
-
if (device->mul_mat ## ID ##
|
|
4273
|
+
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
|
|
3666
4274
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3667
4275
|
} \
|
|
3668
|
-
if (device->mul_mat ## ID ##
|
|
4276
|
+
if (device->mul_mat ## ID ## _m_int[TYPE]) { \
|
|
3669
4277
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3670
4278
|
} \
|
|
3671
|
-
if (device->mul_mat ## ID ##
|
|
4279
|
+
if (device->mul_mat ## ID ## _s_int[TYPE]) { \
|
|
3672
4280
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3673
4281
|
} \
|
|
3674
4282
|
|
|
@@ -3682,14 +4290,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3682
4290
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
3683
4291
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
3684
4292
|
|
|
3685
|
-
|
|
4293
|
+
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
3686
4294
|
|
|
4295
|
+
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3687
4296
|
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3688
4297
|
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3689
4298
|
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3690
4299
|
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3691
4300
|
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3692
|
-
|
|
3693
4301
|
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3694
4302
|
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3695
4303
|
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
@@ -3705,6 +4313,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3705
4313
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3706
4314
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3707
4315
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
4316
|
+
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3708
4317
|
|
|
3709
4318
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3710
4319
|
if (device->integer_dot_product) {
|
|
@@ -3728,8 +4337,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3728
4337
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
3729
4338
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
3730
4339
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
3731
|
-
|
|
3732
|
-
|
|
4340
|
+
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
4341
|
+
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3733
4342
|
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3734
4343
|
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3735
4344
|
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
@@ -3750,6 +4359,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3750
4359
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3751
4360
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3752
4361
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
4362
|
+
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3753
4363
|
|
|
3754
4364
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3755
4365
|
if (device->integer_dot_product) {
|
|
@@ -3772,8 +4382,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3772
4382
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3773
4383
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3774
4384
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3775
|
-
|
|
3776
|
-
|
|
4385
|
+
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
4386
|
+
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3777
4387
|
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3778
4388
|
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3779
4389
|
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
@@ -3794,6 +4404,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3794
4404
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3795
4405
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3796
4406
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
4407
|
+
CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3797
4408
|
|
|
3798
4409
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3799
4410
|
if (device->integer_dot_product) {
|
|
@@ -3816,6 +4427,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3816
4427
|
#undef CREATE_MM2
|
|
3817
4428
|
#undef CREATE_MMQ
|
|
3818
4429
|
#undef CREATE_MM
|
|
4430
|
+
#undef CREATE_MM_NODOT2
|
|
3819
4431
|
} else {
|
|
3820
4432
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
3821
4433
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
|
@@ -3833,11 +4445,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3833
4445
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
3834
4446
|
|
|
3835
4447
|
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
3836
|
-
if (device->mul_mat ## ID ##
|
|
4448
|
+
if (device->mul_mat ## ID ## _l_int[TYPE]) \
|
|
3837
4449
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
3838
|
-
if (device->mul_mat ## ID ##
|
|
4450
|
+
if (device->mul_mat ## ID ## _m_int[TYPE]) \
|
|
3839
4451
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
3840
|
-
if (device->mul_mat ## ID ##
|
|
4452
|
+
if (device->mul_mat ## ID ## _s_int[TYPE]) \
|
|
3841
4453
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
3842
4454
|
|
|
3843
4455
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
@@ -3847,6 +4459,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3847
4459
|
|
|
3848
4460
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
3849
4461
|
|
|
4462
|
+
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3850
4463
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3851
4464
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3852
4465
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
@@ -3868,6 +4481,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3868
4481
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3869
4482
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3870
4483
|
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
4484
|
+
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
|
|
3871
4485
|
|
|
3872
4486
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3873
4487
|
if (device->integer_dot_product) {
|
|
@@ -3891,6 +4505,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3891
4505
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
3892
4506
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
|
|
3893
4507
|
|
|
4508
|
+
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3894
4509
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3895
4510
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3896
4511
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
@@ -3911,12 +4526,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3911
4526
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3912
4527
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3913
4528
|
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
4529
|
+
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
|
|
3914
4530
|
} else {
|
|
3915
4531
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3916
4532
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3917
4533
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3918
4534
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3919
4535
|
|
|
4536
|
+
CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3920
4537
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3921
4538
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3922
4539
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
@@ -3937,6 +4554,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3937
4554
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3938
4555
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3939
4556
|
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
4557
|
+
CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3940
4558
|
}
|
|
3941
4559
|
}
|
|
3942
4560
|
// reusing CREATE_MM from the fp32 path
|
|
@@ -3956,11 +4574,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3956
4574
|
m_wg_denoms = { 64, 64, 1 };
|
|
3957
4575
|
s_wg_denoms = { 32, 32, 1 };
|
|
3958
4576
|
|
|
3959
|
-
if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
|
|
3960
|
-
// Xe2/Xe3 - bf16 warptile performance tuning
|
|
3961
|
-
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
|
|
3962
|
-
}
|
|
3963
|
-
|
|
3964
4577
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
|
3965
4578
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
|
|
3966
4579
|
}
|
|
@@ -4014,6 +4627,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4014
4627
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4015
4628
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4016
4629
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4630
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4017
4631
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4018
4632
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4019
4633
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
@@ -4034,10 +4648,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4034
4648
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4035
4649
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4036
4650
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4651
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4037
4652
|
|
|
4038
4653
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4039
4654
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4040
4655
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4656
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4041
4657
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4042
4658
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
4043
4659
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
|
|
@@ -4058,6 +4674,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4058
4674
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4059
4675
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4060
4676
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4677
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4061
4678
|
|
|
4062
4679
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
4063
4680
|
if (device->integer_dot_product) {
|
|
@@ -4088,6 +4705,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4088
4705
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
|
|
4089
4706
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
|
4090
4707
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
|
|
4708
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
|
4091
4709
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
|
4092
4710
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
|
4093
4711
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
|
|
@@ -4108,6 +4726,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4108
4726
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4109
4727
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4110
4728
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4729
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
|
|
4111
4730
|
|
|
4112
4731
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
4113
4732
|
if (device->integer_dot_product) {
|
|
@@ -4142,6 +4761,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4142
4761
|
|
|
4143
4762
|
// dequant shaders
|
|
4144
4763
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4764
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1);
|
|
4145
4765
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4146
4766
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4147
4767
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
@@ -4162,11 +4782,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4162
4782
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
|
4163
4783
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4164
4784
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4785
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
4165
4786
|
|
|
4166
4787
|
// get_rows
|
|
4167
4788
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4168
4789
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4169
4790
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4791
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4170
4792
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4171
4793
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4172
4794
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
@@ -4187,11 +4809,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4187
4809
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4188
4810
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4189
4811
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4812
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4190
4813
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4191
4814
|
|
|
4192
4815
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4193
4816
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4194
4817
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
4818
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4195
4819
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4196
4820
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4197
4821
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
@@ -4212,6 +4836,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4212
4836
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4213
4837
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4214
4838
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4839
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
4215
4840
|
|
|
4216
4841
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
4217
4842
|
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
|
@@ -4244,10 +4869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4244
4869
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
|
4245
4870
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
|
4246
4871
|
|
|
4247
|
-
if (device->
|
|
4248
|
-
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
|
4872
|
+
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
|
|
4249
4873
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
|
4250
|
-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16",
|
|
4874
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
|
|
4251
4875
|
}
|
|
4252
4876
|
|
|
4253
4877
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
@@ -4258,6 +4882,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4258
4882
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4259
4883
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4260
4884
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4885
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4261
4886
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4262
4887
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4263
4888
|
|
|
@@ -4266,49 +4891,39 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4266
4891
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4267
4892
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4268
4893
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4894
|
+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4269
4895
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4270
4896
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4271
4897
|
|
|
4272
4898
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
|
4273
4899
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
|
4274
4900
|
|
|
4275
|
-
|
|
4276
|
-
|
|
4277
|
-
|
|
4278
|
-
|
|
4279
|
-
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4283
|
-
|
|
4284
|
-
ggml_vk_create_pipeline(device, device->
|
|
4285
|
-
ggml_vk_create_pipeline(device, device->
|
|
4286
|
-
ggml_vk_create_pipeline(device, device->
|
|
4287
|
-
ggml_vk_create_pipeline(device, device->
|
|
4288
|
-
ggml_vk_create_pipeline(device, device->
|
|
4289
|
-
|
|
4290
|
-
|
|
4291
|
-
#
|
|
4292
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [
|
|
4293
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [
|
|
4294
|
-
|
|
4295
|
-
|
|
4296
|
-
|
|
4297
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4298
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4299
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4300
|
-
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
4301
|
-
|
|
4302
|
-
if (device->float_controls_rte_fp16) {
|
|
4303
|
-
SET_ROWS(_i32, _rte)
|
|
4304
|
-
SET_ROWS(_i64, _rte)
|
|
4305
|
-
} else {
|
|
4306
|
-
SET_ROWS(_i32, )
|
|
4307
|
-
SET_ROWS(_i64, )
|
|
4308
|
-
}
|
|
4901
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4902
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4903
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4904
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4905
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4906
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4907
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
4908
|
+
|
|
4909
|
+
#define SET_ROWS(itype) \
|
|
4910
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4911
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4912
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4913
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4914
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4915
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4916
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4917
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4918
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
|
|
4919
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
4920
|
+
|
|
4921
|
+
SET_ROWS(_i32)
|
|
4922
|
+
SET_ROWS(_i64)
|
|
4309
4923
|
#undef SET_ROWS
|
|
4310
4924
|
|
|
4311
4925
|
|
|
4926
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1);
|
|
4312
4927
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
4313
4928
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
|
4314
4929
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
|
|
@@ -4324,11 +4939,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4324
4939
|
return s;
|
|
4325
4940
|
};
|
|
4326
4941
|
|
|
4327
|
-
bool rte = device->float_controls_rte_fp16;
|
|
4328
4942
|
#define CREATE_BINARY(name, namemod, spec, bindings) \
|
|
4329
4943
|
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
|
4330
4944
|
ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
|
4331
|
-
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d]
|
|
4945
|
+
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
|
4332
4946
|
"main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
|
4333
4947
|
|
|
4334
4948
|
CREATE_BINARY(add, , {0}, 4)
|
|
@@ -4371,13 +4985,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4371
4985
|
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4372
4986
|
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4373
4987
|
|
|
4374
|
-
|
|
4375
|
-
|
|
4376
|
-
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4377
|
-
} else {
|
|
4378
|
-
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4379
|
-
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4380
|
-
}
|
|
4988
|
+
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4989
|
+
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4381
4990
|
|
|
4382
4991
|
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4383
4992
|
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -4391,9 +5000,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4391
5000
|
|
|
4392
5001
|
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4393
5002
|
|
|
4394
|
-
ggml_vk_create_pipeline(device, device->
|
|
5003
|
+
ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4395
5004
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4396
5005
|
|
|
5006
|
+
ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
5007
|
+
|
|
4397
5008
|
#define CREATE_UNARY(name) \
|
|
4398
5009
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4399
5010
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -4418,19 +5029,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4418
5029
|
CREATE_UNARY(floor)
|
|
4419
5030
|
CREATE_UNARY(trunc)
|
|
4420
5031
|
CREATE_UNARY(sgn)
|
|
5032
|
+
CREATE_UNARY(exp)
|
|
4421
5033
|
#undef CREATE_UNARY
|
|
4422
5034
|
|
|
4423
|
-
#define CREATE_UNARY_RTE(name) \
|
|
4424
|
-
if (device->float_controls_rte_fp16) { \
|
|
4425
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4426
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4427
|
-
} else { \
|
|
4428
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4429
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4430
|
-
}
|
|
4431
|
-
CREATE_UNARY_RTE(exp)
|
|
4432
|
-
#undef CREATE_UNARY_RTE
|
|
4433
|
-
|
|
4434
5035
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
4435
5036
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
4436
5037
|
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -4438,15 +5039,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4438
5039
|
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
4439
5040
|
|
|
4440
5041
|
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
5042
|
+
ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
4441
5043
|
|
|
4442
5044
|
#define CREATE_GLU(name) \
|
|
4443
|
-
|
|
4444
|
-
|
|
4445
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
4446
|
-
} else { \
|
|
4447
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
4448
|
-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
4449
|
-
}
|
|
5045
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
5046
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
|
4450
5047
|
|
|
4451
5048
|
CREATE_GLU(geglu)
|
|
4452
5049
|
CREATE_GLU(reglu)
|
|
@@ -4479,25 +5076,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4479
5076
|
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4480
5077
|
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4481
5078
|
|
|
4482
|
-
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4487
|
-
|
|
4488
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4489
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4490
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4491
|
-
} else {
|
|
4492
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4493
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4494
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4495
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5079
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5080
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5081
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5082
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4496
5083
|
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
}
|
|
5084
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5085
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
5086
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
4501
5087
|
|
|
4502
5088
|
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
|
4503
5089
|
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
|
@@ -4531,6 +5117,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4531
5117
|
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
4532
5118
|
|
|
4533
5119
|
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
5120
|
+
// Intel Arc B390 was observed segfaulting with this shader.
|
|
5121
|
+
if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) {
|
|
5122
|
+
int idx = 0;
|
|
5123
|
+
for (uint32_t n : {64, 128, 256, 512}) {
|
|
5124
|
+
if (device->subgroup_size <= n) {
|
|
5125
|
+
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size);
|
|
5126
|
+
}
|
|
5127
|
+
++idx;
|
|
5128
|
+
}
|
|
5129
|
+
} else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) {
|
|
5130
|
+
// Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147
|
|
5131
|
+
int idx = 0;
|
|
5132
|
+
for (uint32_t n : {64, 128, 256, 512}) {
|
|
5133
|
+
const uint32_t block_size = std::min(device->subgroup_size, n);
|
|
5134
|
+
ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1);
|
|
5135
|
+
++idx;
|
|
5136
|
+
}
|
|
5137
|
+
}
|
|
4534
5138
|
|
|
4535
5139
|
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
|
|
4536
5140
|
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
|
|
@@ -4559,13 +5163,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4559
5163
|
#define IM2COL(bda) \
|
|
4560
5164
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
4561
5165
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
4562
|
-
|
|
4563
|
-
|
|
4564
|
-
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
4565
|
-
} else { \
|
|
4566
|
-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
4567
|
-
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
|
4568
|
-
}
|
|
5166
|
+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
|
5167
|
+
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
|
4569
5168
|
if (device->shader_int64 && device->buffer_device_address) {
|
|
4570
5169
|
IM2COL(_bda)
|
|
4571
5170
|
} else {
|
|
@@ -4576,6 +5175,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4576
5175
|
|
|
4577
5176
|
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
|
|
4578
5177
|
|
|
5178
|
+
ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
|
5179
|
+
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
|
5180
|
+
ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
|
|
5181
|
+
|
|
4579
5182
|
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
|
4580
5183
|
|
|
4581
5184
|
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
|
@@ -4589,12 +5192,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4589
5192
|
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
|
|
4590
5193
|
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
|
|
4591
5194
|
};
|
|
5195
|
+
const bool use_subgroup_reduce = device->subgroup_arithmetic;
|
|
4592
5196
|
for (uint32_t si = 0; si < 3; si++) {
|
|
5197
|
+
const uint32_t S_V = gdn_sizes[si];
|
|
5198
|
+
GGML_ASSERT(is_pow2(S_V));
|
|
5199
|
+
|
|
5200
|
+
uint32_t lanes_per_column;
|
|
5201
|
+
if (S_V >= 128u && device->subgroup_clustered) {
|
|
5202
|
+
lanes_per_column = 8u;
|
|
5203
|
+
} else {
|
|
5204
|
+
// Use largest power-of-two that divides both S_V and subgroup_size so that
|
|
5205
|
+
// (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
|
|
5206
|
+
// This means we don't need extra bounds checking logic in the shader.
|
|
5207
|
+
lanes_per_column = std::min(S_V, device->subgroup_size);
|
|
5208
|
+
}
|
|
5209
|
+
|
|
5210
|
+
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
|
|
5211
|
+
size_t gdn_len;
|
|
5212
|
+
const void * gdn_data;
|
|
5213
|
+
if (use_subgroup_reduce && need_clustered_shader) {
|
|
5214
|
+
gdn_len = gated_delta_net_f32_len;
|
|
5215
|
+
gdn_data = (const void *)gated_delta_net_f32_data;
|
|
5216
|
+
} else if (use_subgroup_reduce) {
|
|
5217
|
+
gdn_len = gated_delta_net_f32_nocluster_len;
|
|
5218
|
+
gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
|
|
5219
|
+
} else {
|
|
5220
|
+
gdn_len = gated_delta_net_f32_shmem_len;
|
|
5221
|
+
gdn_data = (const void *)gated_delta_net_f32_shmem_data;
|
|
5222
|
+
}
|
|
5223
|
+
|
|
5224
|
+
const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
|
|
5225
|
+
const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
|
|
5226
|
+
|
|
4593
5227
|
for (uint32_t kda = 0; kda < 2; kda++) {
|
|
4594
5228
|
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
|
|
4595
|
-
gdn_names[si][kda],
|
|
4596
|
-
|
|
4597
|
-
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
|
|
5229
|
+
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
|
5230
|
+
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
|
|
4598
5231
|
}
|
|
4599
5232
|
}
|
|
4600
5233
|
}
|
|
@@ -4607,7 +5240,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4607
5240
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
|
4608
5241
|
}
|
|
4609
5242
|
|
|
4610
|
-
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32,
|
|
5243
|
+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
|
|
5244
|
+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
|
|
5245
|
+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
|
|
4611
5246
|
|
|
4612
5247
|
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
4613
5248
|
|
|
@@ -4615,7 +5250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4615
5250
|
|
|
4616
5251
|
// conv2d, conv_transpose_2d
|
|
4617
5252
|
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
|
4618
|
-
|
|
5253
|
+
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
|
|
5254
|
+
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
|
|
4619
5255
|
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
|
|
4620
5256
|
uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8;
|
|
4621
5257
|
uint32_t conv2d_SHMEM_PAD = 4;
|
|
@@ -4654,18 +5290,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4654
5290
|
conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
|
|
4655
5291
|
}
|
|
4656
5292
|
|
|
4657
|
-
|
|
4658
|
-
|
|
4659
|
-
|
|
5293
|
+
// cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size).
|
|
5294
|
+
// Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader.
|
|
5295
|
+
// Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need
|
|
5296
|
+
// subgroup_size_control to force the driver to actually use it.
|
|
5297
|
+
bool conv2d_use_cm1 = false;
|
|
5298
|
+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
5299
|
+
conv2d_use_cm1 = !device->coopmat2 &&
|
|
5300
|
+
device->coopmat_support && device->coopmat_support_16x16x16_f16acc &&
|
|
5301
|
+
device->subgroup_size_control &&
|
|
5302
|
+
(device->subgroup_size == 32 || device->subgroup_size == 64) &&
|
|
5303
|
+
s != CONV_SHAPE_128x128;
|
|
5304
|
+
#endif
|
|
5305
|
+
|
|
5306
|
+
const uint32_t conv2d_cm1_shmem_pad = 8;
|
|
5307
|
+
|
|
5308
|
+
auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) {
|
|
5309
|
+
const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float);
|
|
5310
|
+
const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u;
|
|
5311
|
+
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
|
|
5312
|
+
};
|
|
5313
|
+
|
|
5314
|
+
// coopmat1 needs to store the output through shared memory, so check up front
|
|
5315
|
+
// whether it'll fit and disable it before applying coopmat1 parameters.
|
|
5316
|
+
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
|
|
5317
|
+
conv2d_use_cm1 = false;
|
|
5318
|
+
}
|
|
5319
|
+
|
|
5320
|
+
uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise
|
|
5321
|
+
if (conv2d_use_cm1) {
|
|
5322
|
+
conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad;
|
|
5323
|
+
// 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256
|
|
5324
|
+
// (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64).
|
|
5325
|
+
const bool sg64 = (device->subgroup_size == 64);
|
|
5326
|
+
switch (s) {
|
|
5327
|
+
case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break;
|
|
5328
|
+
case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break;
|
|
5329
|
+
case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break;
|
|
5330
|
+
default: break;
|
|
5331
|
+
}
|
|
5332
|
+
const uint32_t warps_M = conv2d_BS.K / conv2d_WM;
|
|
5333
|
+
const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN;
|
|
5334
|
+
conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size;
|
|
5335
|
+
}
|
|
5336
|
+
|
|
5337
|
+
// stage cm2 accumulator through shmem for coalesced global stores;
|
|
5338
|
+
// skipped on 128x128 where the extra Csh footprint hurts occupancy.
|
|
5339
|
+
// cm1 always uses the staged path.
|
|
5340
|
+
uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u;
|
|
5341
|
+
if (conv2d_use_cm1) {
|
|
5342
|
+
conv2d_csh_store = 1;
|
|
5343
|
+
}
|
|
5344
|
+
|
|
5345
|
+
// shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar
|
|
5346
|
+
const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1;
|
|
5347
|
+
|
|
5348
|
+
// shrink CRS if the non-cm1 config still doesn't fit
|
|
5349
|
+
if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) {
|
|
5350
|
+
GGML_ASSERT(!conv2d_use_cm1);
|
|
4660
5351
|
conv2d_BS.CRS = 8;
|
|
4661
5352
|
if (use_collectives) {
|
|
4662
5353
|
conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
|
|
4663
5354
|
}
|
|
5355
|
+
conv2d_csh_store = 0;
|
|
4664
5356
|
}
|
|
4665
5357
|
|
|
4666
5358
|
std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
|
|
4667
5359
|
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
|
|
4668
5360
|
|
|
5361
|
+
// cm1 needs a fixed subgroup width to match the WG_SIZE we computed
|
|
5362
|
+
const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0;
|
|
5363
|
+
|
|
4669
5364
|
#define CREATE_CONV(name, type_suffix, spv_suffix) \
|
|
4670
5365
|
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
|
|
4671
5366
|
const vk_conv2d_pipeline_state &state = c.first; \
|
|
@@ -4678,10 +5373,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4678
5373
|
spec_constants_cpy.push_back(state.d1); \
|
|
4679
5374
|
spec_constants_cpy.push_back(state.KW); \
|
|
4680
5375
|
spec_constants_cpy.push_back(state.KH); \
|
|
5376
|
+
spec_constants_cpy.push_back(state.aligned); \
|
|
5377
|
+
spec_constants_cpy.push_back(conv2d_csh_store); \
|
|
5378
|
+
spec_constants_cpy.push_back(conv2d_WM); \
|
|
5379
|
+
spec_constants_cpy.push_back(conv2d_WN); \
|
|
4681
5380
|
ggml_vk_create_pipeline( \
|
|
4682
5381
|
device, c.second, #name #type_suffix, \
|
|
4683
5382
|
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
|
|
4684
|
-
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
|
|
5383
|
+
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \
|
|
4685
5384
|
}
|
|
4686
5385
|
#define CREATE_CONVS(spv_suffix) \
|
|
4687
5386
|
CREATE_CONV(conv2d, _f32, spv_suffix) \
|
|
@@ -4692,6 +5391,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4692
5391
|
if (device->coopmat2) {
|
|
4693
5392
|
CREATE_CONVS(_cm2)
|
|
4694
5393
|
} else
|
|
5394
|
+
#endif
|
|
5395
|
+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
5396
|
+
if (conv2d_use_cm1) {
|
|
5397
|
+
CREATE_CONVS(_cm1)
|
|
5398
|
+
} else
|
|
4695
5399
|
#endif
|
|
4696
5400
|
if (conv2d_UNROLL) {
|
|
4697
5401
|
CREATE_CONVS(_unroll)
|
|
@@ -4713,8 +5417,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4713
5417
|
}
|
|
4714
5418
|
}
|
|
4715
5419
|
|
|
4716
|
-
|
|
4717
|
-
|
|
5420
|
+
// Drop compile_mutex so other threads can walk while we compile.
|
|
5421
|
+
compile_lock.unlock();
|
|
5422
|
+
|
|
5423
|
+
// Compile what we claimed; create_pipeline_func reacquires compile_mutex
|
|
5424
|
+
// at the end to flip compile_pending/compiled and notify waiters.
|
|
5425
|
+
if (has_claimed_task) {
|
|
5426
|
+
auto & task = claimed_task;
|
|
5427
|
+
ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
|
|
5428
|
+
task.entrypoint, task.parameter_count, task.wg_denoms,
|
|
5429
|
+
task.specialization_constants, task.disable_robustness,
|
|
5430
|
+
task.require_full_subgroups, task.required_subgroup_size);
|
|
5431
|
+
}
|
|
5432
|
+
|
|
5433
|
+
// Another thread may be compiling the pipeline we need; block on it here.
|
|
5434
|
+
if (wait_pipeline) {
|
|
5435
|
+
std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
|
|
5436
|
+
device->compile_cv.wait(wait_lock, [&] {
|
|
5437
|
+
return wait_pipeline->compiled.load();
|
|
5438
|
+
});
|
|
4718
5439
|
}
|
|
4719
5440
|
}
|
|
4720
5441
|
|
|
@@ -4764,11 +5485,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4764
5485
|
bool amd_shader_core_properties2 = false;
|
|
4765
5486
|
bool pipeline_robustness = false;
|
|
4766
5487
|
bool coopmat2_support = false;
|
|
5488
|
+
bool coopmat2_decode_vector_support = false;
|
|
4767
5489
|
bool pipeline_executable_properties_support = false;
|
|
4768
5490
|
device->coopmat_support = false;
|
|
4769
5491
|
device->integer_dot_product = false;
|
|
4770
5492
|
device->shader_64b_indexing = false;
|
|
4771
5493
|
bool bfloat16_support = false;
|
|
5494
|
+
bool dot2_f16_support = false;
|
|
4772
5495
|
|
|
4773
5496
|
for (const auto& properties : ext_props) {
|
|
4774
5497
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
|
@@ -4798,6 +5521,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4798
5521
|
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
4799
5522
|
coopmat2_support = true;
|
|
4800
5523
|
#endif
|
|
5524
|
+
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
|
5525
|
+
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
|
5526
|
+
coopmat2_decode_vector_support = true;
|
|
4801
5527
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
4802
5528
|
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
4803
5529
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
@@ -4808,6 +5534,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4808
5534
|
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
4809
5535
|
bfloat16_support = true;
|
|
4810
5536
|
#endif
|
|
5537
|
+
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
|
|
5538
|
+
!getenv("GGML_VK_DISABLE_DOT2")) {
|
|
5539
|
+
dot2_f16_support = true;
|
|
4811
5540
|
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
|
|
4812
5541
|
pipeline_executable_properties_support = true;
|
|
4813
5542
|
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
|
|
@@ -4955,6 +5684,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4955
5684
|
#endif
|
|
4956
5685
|
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
4957
5686
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
|
|
5687
|
+
#ifdef __APPLE__
|
|
5688
|
+
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
|
5689
|
+
device->subgroup_shuffle = false;
|
|
5690
|
+
}
|
|
5691
|
+
#endif
|
|
4958
5692
|
device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
4959
5693
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
|
|
4960
5694
|
|
|
@@ -4981,8 +5715,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4981
5715
|
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
|
4982
5716
|
|
|
4983
5717
|
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
4984
|
-
//
|
|
4985
|
-
const
|
|
5718
|
+
// Allow overriding avoiding the graphics queue because it can increase performance on RADV
|
|
5719
|
+
const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr);
|
|
5720
|
+
const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
|
|
4986
5721
|
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
|
|
4987
5722
|
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
|
|
4988
5723
|
|
|
@@ -4998,7 +5733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4998
5733
|
} else {
|
|
4999
5734
|
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
|
|
5000
5735
|
}
|
|
5001
|
-
vk::DeviceCreateInfo device_create_info;
|
|
5736
|
+
vk::DeviceCreateInfo device_create_info{};
|
|
5002
5737
|
std::vector<const char *> device_extensions;
|
|
5003
5738
|
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
|
|
5004
5739
|
|
|
@@ -5074,6 +5809,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5074
5809
|
}
|
|
5075
5810
|
#endif
|
|
5076
5811
|
|
|
5812
|
+
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
|
5813
|
+
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
|
5814
|
+
if (coopmat2_decode_vector_support) {
|
|
5815
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
|
5816
|
+
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
|
5817
|
+
device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME);
|
|
5818
|
+
}
|
|
5819
|
+
|
|
5077
5820
|
#if defined(VK_KHR_shader_bfloat16)
|
|
5078
5821
|
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
|
5079
5822
|
bfloat16_features.pNext = nullptr;
|
|
@@ -5101,6 +5844,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5101
5844
|
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
|
|
5102
5845
|
}
|
|
5103
5846
|
|
|
5847
|
+
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
|
|
5848
|
+
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
|
|
5849
|
+
if (dot2_f16_support) {
|
|
5850
|
+
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
|
|
5851
|
+
last_struct = (VkBaseOutStructure *)&dot2_features;
|
|
5852
|
+
device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product");
|
|
5853
|
+
}
|
|
5854
|
+
|
|
5104
5855
|
VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
|
|
5105
5856
|
pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
|
|
5106
5857
|
if (pipeline_executable_properties_support) {
|
|
@@ -5135,6 +5886,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5135
5886
|
device->bf16 = false;
|
|
5136
5887
|
#endif
|
|
5137
5888
|
|
|
5889
|
+
device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
|
|
5890
|
+
|
|
5138
5891
|
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
|
5139
5892
|
|
|
5140
5893
|
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
|
|
@@ -5193,46 +5946,73 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5193
5946
|
found_fp16_256 = false,
|
|
5194
5947
|
found_fp32_128 = false,
|
|
5195
5948
|
found_fp32_256 = false;
|
|
5949
|
+
bool found_bf16_128 = false,
|
|
5950
|
+
found_bf16_256 = false;
|
|
5196
5951
|
// need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
|
|
5197
5952
|
// with 32x16x16 and 256 with 32x32x16.
|
|
5198
5953
|
for (auto &prop : flexible_dimensions) {
|
|
5199
5954
|
if (prop.saturatingAccumulation == VK_FALSE &&
|
|
5200
|
-
prop.scope == VK_SCOPE_WORKGROUP_KHR
|
|
5201
|
-
|
|
5202
|
-
prop.
|
|
5203
|
-
|
|
5204
|
-
|
|
5205
|
-
prop.
|
|
5206
|
-
|
|
5207
|
-
|
|
5208
|
-
|
|
5209
|
-
prop.
|
|
5210
|
-
|
|
5955
|
+
prop.scope == VK_SCOPE_WORKGROUP_KHR) {
|
|
5956
|
+
|
|
5957
|
+
if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
5958
|
+
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
5959
|
+
|
|
5960
|
+
if (prop.workgroupInvocations == 128 &&
|
|
5961
|
+
prop.MGranularity <= 32 &&
|
|
5962
|
+
prop.NGranularity <= 16 &&
|
|
5963
|
+
prop.KGranularity <= 16) {
|
|
5964
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
5965
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
5966
|
+
found_fp16_128 = true;
|
|
5967
|
+
}
|
|
5968
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
5969
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
|
5970
|
+
found_fp32_128 = true;
|
|
5971
|
+
}
|
|
5211
5972
|
}
|
|
5212
|
-
if (prop.
|
|
5213
|
-
prop.
|
|
5214
|
-
|
|
5973
|
+
if (prop.workgroupInvocations == 256 &&
|
|
5974
|
+
prop.MGranularity <= 32 &&
|
|
5975
|
+
prop.NGranularity <= 32 &&
|
|
5976
|
+
prop.KGranularity <= 16) {
|
|
5977
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
5978
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
5979
|
+
found_fp16_256 = true;
|
|
5980
|
+
}
|
|
5981
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
5982
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
|
5983
|
+
found_fp32_256 = true;
|
|
5984
|
+
}
|
|
5215
5985
|
}
|
|
5216
5986
|
}
|
|
5217
|
-
|
|
5218
|
-
|
|
5219
|
-
|
|
5220
|
-
prop.
|
|
5221
|
-
|
|
5222
|
-
|
|
5223
|
-
|
|
5987
|
+
|
|
5988
|
+
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
5989
|
+
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
|
5990
|
+
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
|
5991
|
+
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
5992
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
|
5993
|
+
|
|
5994
|
+
if (prop.workgroupInvocations == 128 &&
|
|
5995
|
+
prop.MGranularity <= 32 &&
|
|
5996
|
+
prop.NGranularity <= 16 &&
|
|
5997
|
+
prop.KGranularity <= 16) {
|
|
5998
|
+
found_bf16_128 = true;
|
|
5224
5999
|
}
|
|
5225
|
-
if (prop.
|
|
5226
|
-
prop.
|
|
5227
|
-
|
|
6000
|
+
if (prop.workgroupInvocations == 256 &&
|
|
6001
|
+
prop.MGranularity <= 32 &&
|
|
6002
|
+
prop.NGranularity <= 32 &&
|
|
6003
|
+
prop.KGranularity <= 16) {
|
|
6004
|
+
found_bf16_256 = true;
|
|
5228
6005
|
}
|
|
5229
6006
|
}
|
|
6007
|
+
#endif
|
|
5230
6008
|
}
|
|
5231
6009
|
}
|
|
5232
6010
|
if (found_fp16_128 && found_fp16_256 &&
|
|
5233
6011
|
found_fp32_128 && found_fp32_256 &&
|
|
5234
6012
|
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
|
5235
6013
|
device->coopmat2 = true;
|
|
6014
|
+
device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256;
|
|
6015
|
+
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
|
5236
6016
|
}
|
|
5237
6017
|
}
|
|
5238
6018
|
#endif
|
|
@@ -5367,12 +6147,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5367
6147
|
#endif
|
|
5368
6148
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
|
5369
6149
|
|
|
5370
|
-
device_create_info
|
|
5371
|
-
vk::DeviceCreateFlags()
|
|
5372
|
-
device_queue_create_infos
|
|
5373
|
-
|
|
5374
|
-
device_extensions
|
|
5375
|
-
};
|
|
6150
|
+
device_create_info
|
|
6151
|
+
.setFlags(vk::DeviceCreateFlags())
|
|
6152
|
+
.setQueueCreateInfos(device_queue_create_infos)
|
|
6153
|
+
.setPEnabledExtensionNames(device_extensions);
|
|
5376
6154
|
device_create_info.setPNext(&device_features2);
|
|
5377
6155
|
device->device = device->physical_device.createDevice(device_create_info);
|
|
5378
6156
|
|
|
@@ -5392,19 +6170,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5392
6170
|
device->mul_mat_id_m[i] = true;
|
|
5393
6171
|
device->mul_mat_id_s[i] = true;
|
|
5394
6172
|
break;
|
|
5395
|
-
case VK_VENDOR_ID_INTEL:
|
|
5396
|
-
|
|
5397
|
-
|
|
5398
|
-
|
|
5399
|
-
|
|
5400
|
-
|
|
5401
|
-
|
|
5402
|
-
}
|
|
6173
|
+
case VK_VENDOR_ID_INTEL: {
|
|
6174
|
+
// Current Windows driver does not expose BF16 support.
|
|
6175
|
+
// We only want to use l_warptile if coopmat is available and is Xe2+
|
|
6176
|
+
const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
|
|
6177
|
+
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
|
|
6178
|
+
device->mul_mat_l[i] = use_l_warptile;
|
|
6179
|
+
device->mul_mat_id_l[i] = use_l_warptile;
|
|
5403
6180
|
device->mul_mat_m[i] = true;
|
|
5404
6181
|
device->mul_mat_s[i] = true;
|
|
5405
6182
|
device->mul_mat_id_m[i] = true;
|
|
5406
6183
|
device->mul_mat_id_s[i] = true;
|
|
5407
6184
|
break;
|
|
6185
|
+
}
|
|
5408
6186
|
case VK_VENDOR_ID_APPLE:
|
|
5409
6187
|
device->mul_mat_l[i] = false;
|
|
5410
6188
|
device->mul_mat_m[i] = true;
|
|
@@ -5423,6 +6201,26 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5423
6201
|
device->mul_mat_id_s[i] = true;
|
|
5424
6202
|
break;
|
|
5425
6203
|
}
|
|
6204
|
+
|
|
6205
|
+
#if VK_HEADER_VERSION >= 287
|
|
6206
|
+
// Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE.
|
|
6207
|
+
// Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case.
|
|
6208
|
+
if (device->driver_id == vk::DriverId::eMesaHoneykrisp) {
|
|
6209
|
+
device->mul_mat_l[i] = false;
|
|
6210
|
+
device->mul_mat_m[i] = true;
|
|
6211
|
+
device->mul_mat_s[i] = false;
|
|
6212
|
+
device->mul_mat_id_l[i] = false;
|
|
6213
|
+
device->mul_mat_id_m[i] = true;
|
|
6214
|
+
device->mul_mat_id_s[i] = false;
|
|
6215
|
+
}
|
|
6216
|
+
#endif
|
|
6217
|
+
|
|
6218
|
+
device->mul_mat_l_int[i] = device->mul_mat_l[i];
|
|
6219
|
+
device->mul_mat_m_int[i] = device->mul_mat_m[i];
|
|
6220
|
+
device->mul_mat_s_int[i] = device->mul_mat_s[i];
|
|
6221
|
+
device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i];
|
|
6222
|
+
device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i];
|
|
6223
|
+
device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i];
|
|
5426
6224
|
}
|
|
5427
6225
|
|
|
5428
6226
|
|
|
@@ -5443,11 +6241,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5443
6241
|
|
|
5444
6242
|
ggml_vk_load_shaders(device);
|
|
5445
6243
|
|
|
6244
|
+
// Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled.
|
|
6245
|
+
const bool prefers_transfer_queue =
|
|
6246
|
+
device->vendor_id == VK_VENDOR_ID_AMD &&
|
|
6247
|
+
device->architecture != AMD_GCN &&
|
|
6248
|
+
!device->uma &&
|
|
6249
|
+
!allow_graphics_queue;
|
|
6250
|
+
|
|
5446
6251
|
if (!device->single_queue) {
|
|
5447
6252
|
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
|
|
5448
6253
|
ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
|
|
5449
6254
|
|
|
5450
|
-
device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
|
6255
|
+
device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
|
5451
6256
|
} else {
|
|
5452
6257
|
// TODO: Use pointer or reference to avoid copy
|
|
5453
6258
|
device->transfer_queue.copyFrom(device->compute_queue);
|
|
@@ -5507,8 +6312,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
5507
6312
|
bool fp16_compute = false;
|
|
5508
6313
|
bool coopmat_support = false;
|
|
5509
6314
|
bool coopmat2_support = false;
|
|
6315
|
+
bool coopmat2_decode_vector_support = false;
|
|
5510
6316
|
bool integer_dot_product = false;
|
|
5511
6317
|
bool bfloat16_support = false;
|
|
6318
|
+
bool dot2_f16_support = false;
|
|
5512
6319
|
|
|
5513
6320
|
for (auto properties : ext_props) {
|
|
5514
6321
|
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
@@ -5525,6 +6332,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
5525
6332
|
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
5526
6333
|
coopmat2_support = true;
|
|
5527
6334
|
#endif
|
|
6335
|
+
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
|
6336
|
+
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
|
6337
|
+
coopmat2_decode_vector_support = true;
|
|
5528
6338
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
5529
6339
|
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
5530
6340
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
@@ -5535,6 +6345,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
5535
6345
|
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
5536
6346
|
bfloat16_support = true;
|
|
5537
6347
|
#endif
|
|
6348
|
+
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
|
|
6349
|
+
!getenv("GGML_VK_DISABLE_DOT2")) {
|
|
6350
|
+
dot2_f16_support = true;
|
|
5538
6351
|
}
|
|
5539
6352
|
}
|
|
5540
6353
|
|
|
@@ -5609,6 +6422,29 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
5609
6422
|
}
|
|
5610
6423
|
#endif
|
|
5611
6424
|
|
|
6425
|
+
#if defined(VK_NV_cooperative_matrix2)
|
|
6426
|
+
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
|
|
6427
|
+
coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
|
|
6428
|
+
if (coopmat2_support) {
|
|
6429
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
|
|
6430
|
+
last_struct = (VkBaseOutStructure *)&coopmat2_features;
|
|
6431
|
+
}
|
|
6432
|
+
#endif
|
|
6433
|
+
|
|
6434
|
+
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
|
6435
|
+
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
|
6436
|
+
if (coopmat2_decode_vector_support) {
|
|
6437
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
|
6438
|
+
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
|
6439
|
+
}
|
|
6440
|
+
|
|
6441
|
+
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
|
|
6442
|
+
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
|
|
6443
|
+
if (dot2_f16_support) {
|
|
6444
|
+
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
|
|
6445
|
+
last_struct = (VkBaseOutStructure *)&dot2_features;
|
|
6446
|
+
}
|
|
6447
|
+
|
|
5612
6448
|
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
|
5613
6449
|
|
|
5614
6450
|
fp16 = fp16 && vk12_features.shaderFloat16;
|
|
@@ -5633,11 +6469,34 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
5633
6469
|
#endif
|
|
5634
6470
|
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
|
|
5635
6471
|
|
|
5636
|
-
|
|
6472
|
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
6473
|
+
coopmat2_support = coopmat2_support &&
|
|
6474
|
+
coopmat2_features.cooperativeMatrixWorkgroupScope &&
|
|
6475
|
+
coopmat2_features.cooperativeMatrixFlexibleDimensions &&
|
|
6476
|
+
coopmat2_features.cooperativeMatrixReductions &&
|
|
6477
|
+
coopmat2_features.cooperativeMatrixConversions &&
|
|
6478
|
+
coopmat2_features.cooperativeMatrixPerElementOperations &&
|
|
6479
|
+
coopmat2_features.cooperativeMatrixTensorAddressing &&
|
|
6480
|
+
coopmat2_features.cooperativeMatrixBlockLoads;
|
|
6481
|
+
#else
|
|
6482
|
+
coopmat2_support = false;
|
|
6483
|
+
#endif
|
|
6484
|
+
|
|
6485
|
+
coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
|
6486
|
+
#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
|
6487
|
+
coopmat2_decode_vector_support = false;
|
|
6488
|
+
#endif
|
|
6489
|
+
|
|
6490
|
+
std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2")
|
|
6491
|
+
: coopmat_support ? "KHR_coopmat"
|
|
6492
|
+
: "none";
|
|
6493
|
+
|
|
6494
|
+
bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
|
|
6495
|
+
const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0";
|
|
5637
6496
|
|
|
5638
6497
|
std::string device_name = props2.properties.deviceName.data();
|
|
5639
|
-
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %
|
|
5640
|
-
idx, device_name.c_str(), driver_props.driverName.data(), uma,
|
|
6498
|
+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
|
6499
|
+
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size,
|
|
5641
6500
|
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
|
|
5642
6501
|
|
|
5643
6502
|
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
|
@@ -5953,6 +6812,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
|
|
5953
6812
|
VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
|
|
5954
6813
|
switch (type) {
|
|
5955
6814
|
case GGML_TYPE_F32:
|
|
6815
|
+
case GGML_TYPE_Q1_0:
|
|
5956
6816
|
case GGML_TYPE_Q4_0:
|
|
5957
6817
|
case GGML_TYPE_Q4_1:
|
|
5958
6818
|
case GGML_TYPE_Q5_0:
|
|
@@ -5973,6 +6833,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
|
|
5973
6833
|
case GGML_TYPE_IQ4_XS:
|
|
5974
6834
|
case GGML_TYPE_IQ4_NL:
|
|
5975
6835
|
case GGML_TYPE_MXFP4:
|
|
6836
|
+
case GGML_TYPE_NVFP4:
|
|
5976
6837
|
break;
|
|
5977
6838
|
default:
|
|
5978
6839
|
return nullptr;
|
|
@@ -6024,6 +6885,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
6024
6885
|
}
|
|
6025
6886
|
|
|
6026
6887
|
switch (src0_type) {
|
|
6888
|
+
case GGML_TYPE_Q1_0:
|
|
6027
6889
|
case GGML_TYPE_Q4_0:
|
|
6028
6890
|
case GGML_TYPE_Q4_1:
|
|
6029
6891
|
case GGML_TYPE_Q5_0:
|
|
@@ -6044,6 +6906,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
6044
6906
|
case GGML_TYPE_IQ4_XS:
|
|
6045
6907
|
case GGML_TYPE_IQ4_NL:
|
|
6046
6908
|
case GGML_TYPE_MXFP4:
|
|
6909
|
+
case GGML_TYPE_NVFP4:
|
|
6047
6910
|
break;
|
|
6048
6911
|
default:
|
|
6049
6912
|
return nullptr;
|
|
@@ -6089,6 +6952,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|
|
6089
6952
|
case GGML_TYPE_F32:
|
|
6090
6953
|
case GGML_TYPE_F16:
|
|
6091
6954
|
case GGML_TYPE_BF16:
|
|
6955
|
+
case GGML_TYPE_Q1_0:
|
|
6092
6956
|
case GGML_TYPE_Q4_0:
|
|
6093
6957
|
case GGML_TYPE_Q4_1:
|
|
6094
6958
|
case GGML_TYPE_Q5_0:
|
|
@@ -6109,6 +6973,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|
|
6109
6973
|
case GGML_TYPE_IQ4_XS:
|
|
6110
6974
|
case GGML_TYPE_IQ4_NL:
|
|
6111
6975
|
case GGML_TYPE_MXFP4:
|
|
6976
|
+
case GGML_TYPE_NVFP4:
|
|
6112
6977
|
break;
|
|
6113
6978
|
default:
|
|
6114
6979
|
return nullptr;
|
|
@@ -6179,6 +7044,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
|
6179
7044
|
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
|
|
6180
7045
|
|
|
6181
7046
|
switch (src0_type) {
|
|
7047
|
+
case GGML_TYPE_Q1_0:
|
|
6182
7048
|
case GGML_TYPE_Q4_0:
|
|
6183
7049
|
case GGML_TYPE_Q4_1:
|
|
6184
7050
|
case GGML_TYPE_Q5_0:
|
|
@@ -6199,6 +7065,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
|
6199
7065
|
case GGML_TYPE_IQ4_XS:
|
|
6200
7066
|
case GGML_TYPE_IQ4_NL:
|
|
6201
7067
|
case GGML_TYPE_MXFP4:
|
|
7068
|
+
case GGML_TYPE_NVFP4:
|
|
6202
7069
|
break;
|
|
6203
7070
|
default:
|
|
6204
7071
|
return nullptr;
|
|
@@ -6247,6 +7114,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
|
|
6247
7114
|
case GGML_TYPE_F32:
|
|
6248
7115
|
case GGML_TYPE_F16:
|
|
6249
7116
|
case GGML_TYPE_BF16:
|
|
7117
|
+
case GGML_TYPE_Q1_0:
|
|
6250
7118
|
case GGML_TYPE_Q4_0:
|
|
6251
7119
|
case GGML_TYPE_Q4_1:
|
|
6252
7120
|
case GGML_TYPE_Q5_0:
|
|
@@ -6267,6 +7135,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
|
|
6267
7135
|
case GGML_TYPE_IQ4_XS:
|
|
6268
7136
|
case GGML_TYPE_IQ4_NL:
|
|
6269
7137
|
case GGML_TYPE_MXFP4:
|
|
7138
|
+
case GGML_TYPE_NVFP4:
|
|
6270
7139
|
break;
|
|
6271
7140
|
default:
|
|
6272
7141
|
return nullptr;
|
|
@@ -6313,7 +7182,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
|
|
6313
7182
|
return nullptr;
|
|
6314
7183
|
}
|
|
6315
7184
|
|
|
6316
|
-
std::lock_guard<std::
|
|
7185
|
+
std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
|
|
6317
7186
|
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
|
|
6318
7187
|
|
|
6319
7188
|
return buf->ptr;
|
|
@@ -6324,7 +7193,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
6324
7193
|
return;
|
|
6325
7194
|
}
|
|
6326
7195
|
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
|
|
6327
|
-
std::lock_guard<std::
|
|
7196
|
+
std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
|
|
6328
7197
|
|
|
6329
7198
|
vk_buffer buf;
|
|
6330
7199
|
size_t index;
|
|
@@ -6348,7 +7217,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
6348
7217
|
}
|
|
6349
7218
|
|
|
6350
7219
|
static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
|
|
6351
|
-
std::
|
|
7220
|
+
std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex);
|
|
6352
7221
|
buf = nullptr;
|
|
6353
7222
|
buf_offset = 0;
|
|
6354
7223
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -6392,6 +7261,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
|
|
|
6392
7261
|
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
|
|
6393
7262
|
for (auto& cmd_buffer : pool.cmd_buffers) {
|
|
6394
7263
|
if (!cmd_buffer.in_use) {
|
|
7264
|
+
cmd_buffer.use_counter++;
|
|
6395
7265
|
cmd_buffer.in_use = true;
|
|
6396
7266
|
return &cmd_buffer;
|
|
6397
7267
|
}
|
|
@@ -6468,13 +7338,6 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
|
|
|
6468
7338
|
subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
|
|
6469
7339
|
}
|
|
6470
7340
|
|
|
6471
|
-
static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
|
|
6472
|
-
s.buffer->buf.end();
|
|
6473
|
-
|
|
6474
|
-
s.wait_semaphores = std::move(wait_semaphores);
|
|
6475
|
-
s.signal_semaphores = std::move(signal_semaphores);
|
|
6476
|
-
}
|
|
6477
|
-
|
|
6478
7341
|
static void ggml_vk_ctx_end(vk_context& ctx) {
|
|
6479
7342
|
VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
|
|
6480
7343
|
if (ctx->s == nullptr) {
|
|
@@ -6496,14 +7359,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
|
|
|
6496
7359
|
}
|
|
6497
7360
|
|
|
6498
7361
|
static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
|
|
7362
|
+
vk_context result;
|
|
6499
7363
|
if (!ctx->compute_ctx.expired()) {
|
|
6500
|
-
|
|
6501
|
-
}
|
|
6502
|
-
|
|
6503
|
-
vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
7364
|
+
result = ctx->compute_ctx.lock();
|
|
7365
|
+
} else {
|
|
7366
|
+
result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
6504
7367
|
|
|
6505
|
-
|
|
6506
|
-
|
|
7368
|
+
ctx->compute_ctx = result;
|
|
7369
|
+
ggml_vk_ctx_begin(ctx->device, result);
|
|
7370
|
+
}
|
|
6507
7371
|
|
|
6508
7372
|
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
|
|
6509
7373
|
result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
|
|
@@ -6626,7 +7490,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
|
|
|
6626
7490
|
const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
|
|
6627
7491
|
const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
|
|
6628
7492
|
for (uint64_t i0 = 0; i0 < ne0; i0++) {
|
|
6629
|
-
slices.push_back({ s_off +
|
|
7493
|
+
slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
|
|
6630
7494
|
}
|
|
6631
7495
|
}
|
|
6632
7496
|
}
|
|
@@ -6674,7 +7538,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
|
|
|
6674
7538
|
}
|
|
6675
7539
|
}
|
|
6676
7540
|
|
|
6677
|
-
static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
|
|
7541
|
+
static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
|
|
6678
7542
|
VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
|
|
6679
7543
|
// Check if src is pinned memory
|
|
6680
7544
|
vk_buffer buf = nullptr;
|
|
@@ -6684,7 +7548,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6684
7548
|
if (buf != nullptr) {
|
|
6685
7549
|
// Memory is pinned, use as staging buffer
|
|
6686
7550
|
std::vector<vk::BufferCopy> slices(1);
|
|
6687
|
-
if (width == spitch) {
|
|
7551
|
+
if (width == spitch && width == dpitch) {
|
|
6688
7552
|
// Only do single write if stride is equal
|
|
6689
7553
|
slices[0].srcOffset = buf_offset;
|
|
6690
7554
|
slices[0].dstOffset = offset;
|
|
@@ -6693,7 +7557,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6693
7557
|
slices.resize(height);
|
|
6694
7558
|
for (size_t i = 0; i < height; i++) {
|
|
6695
7559
|
slices[i].srcOffset = buf_offset + i * spitch;
|
|
6696
|
-
slices[i].dstOffset = offset + i *
|
|
7560
|
+
slices[i].dstOffset = offset + i * dpitch;
|
|
6697
7561
|
slices[i].size = width;
|
|
6698
7562
|
}
|
|
6699
7563
|
}
|
|
@@ -6710,21 +7574,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6710
7574
|
}
|
|
6711
7575
|
|
|
6712
7576
|
// Staging buffer required
|
|
6713
|
-
const size_t
|
|
6714
|
-
ggml_vk_ensure_sync_staging_buffer(dst->device,
|
|
7577
|
+
const size_t staging_size = width * height;
|
|
7578
|
+
ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size);
|
|
6715
7579
|
|
|
6716
7580
|
vk_buffer& staging_buffer = dst->device->sync_staging;
|
|
6717
7581
|
|
|
6718
|
-
|
|
6719
|
-
|
|
6720
|
-
|
|
6721
|
-
|
|
7582
|
+
std::vector<vk::BufferCopy> slices(1);
|
|
7583
|
+
if (width == dpitch) {
|
|
7584
|
+
slices[0].srcOffset = 0;
|
|
7585
|
+
slices[0].dstOffset = offset;
|
|
7586
|
+
slices[0].size = staging_size;
|
|
7587
|
+
} else {
|
|
7588
|
+
slices.resize(height);
|
|
7589
|
+
for (size_t i = 0; i < height; i++) {
|
|
7590
|
+
slices[i].srcOffset = i * width;
|
|
7591
|
+
slices[i].dstOffset = offset + i * dpitch;
|
|
7592
|
+
slices[i].size = width;
|
|
7593
|
+
}
|
|
7594
|
+
}
|
|
6722
7595
|
|
|
6723
7596
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6724
|
-
|
|
7597
|
+
subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices);
|
|
6725
7598
|
|
|
6726
7599
|
if (width == spitch) {
|
|
6727
|
-
deferred_memcpy((uint8_t *)staging_buffer->ptr, src,
|
|
7600
|
+
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys);
|
|
6728
7601
|
} else {
|
|
6729
7602
|
for (size_t i = 0; i < height; i++) {
|
|
6730
7603
|
deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
|
|
@@ -6735,24 +7608,28 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6735
7608
|
|
|
6736
7609
|
static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
|
|
6737
7610
|
VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
|
|
6738
|
-
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
|
|
7611
|
+
return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging);
|
|
6739
7612
|
}
|
|
6740
7613
|
|
|
6741
|
-
static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
|
|
7614
|
+
static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) {
|
|
6742
7615
|
VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
|
|
6743
7616
|
// Buffer is already mapped
|
|
6744
7617
|
if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
|
|
6745
7618
|
GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
|
|
6746
7619
|
|
|
6747
|
-
|
|
6748
|
-
memcpy((uint8_t *)dst->ptr + offset
|
|
7620
|
+
if (width == spitch && width == dpitch) {
|
|
7621
|
+
memcpy((uint8_t *)dst->ptr + offset, src, width * height);
|
|
7622
|
+
} else {
|
|
7623
|
+
for (size_t i = 0; i < height; i++) {
|
|
7624
|
+
memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width);
|
|
7625
|
+
}
|
|
6749
7626
|
}
|
|
6750
7627
|
} else {
|
|
6751
7628
|
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
6752
7629
|
|
|
6753
7630
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
6754
7631
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
6755
|
-
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
|
|
7632
|
+
bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true);
|
|
6756
7633
|
GGML_ASSERT(ret);
|
|
6757
7634
|
ggml_vk_ctx_end(subctx);
|
|
6758
7635
|
|
|
@@ -6773,7 +7650,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
|
|
|
6773
7650
|
|
|
6774
7651
|
static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
|
|
6775
7652
|
VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
|
|
6776
|
-
ggml_vk_buffer_write_2d(dst, offset, src,
|
|
7653
|
+
ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1);
|
|
6777
7654
|
}
|
|
6778
7655
|
|
|
6779
7656
|
static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
|
|
@@ -6819,15 +7696,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
|
|
|
6819
7696
|
}
|
|
6820
7697
|
|
|
6821
7698
|
// Fall back to staging buffer
|
|
6822
|
-
const size_t
|
|
6823
|
-
ggml_vk_ensure_sync_staging_buffer(src->device,
|
|
7699
|
+
const size_t staging_size = width * height;
|
|
7700
|
+
ggml_vk_ensure_sync_staging_buffer(src->device, staging_size);
|
|
6824
7701
|
|
|
6825
7702
|
vk_buffer& staging_buffer = src->device->sync_staging;
|
|
6826
7703
|
|
|
7704
|
+
std::vector<vk::BufferCopy> staging_slices(1);
|
|
7705
|
+
if (width == spitch) {
|
|
7706
|
+
staging_slices[0].srcOffset = offset;
|
|
7707
|
+
staging_slices[0].dstOffset = 0;
|
|
7708
|
+
staging_slices[0].size = staging_size;
|
|
7709
|
+
} else {
|
|
7710
|
+
staging_slices.resize(height);
|
|
7711
|
+
for (size_t i = 0; i < height; i++) {
|
|
7712
|
+
staging_slices[i].srcOffset = offset + i * spitch;
|
|
7713
|
+
staging_slices[i].dstOffset = i * width;
|
|
7714
|
+
staging_slices[i].size = width;
|
|
7715
|
+
}
|
|
7716
|
+
}
|
|
7717
|
+
|
|
6827
7718
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6828
|
-
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer,
|
|
7719
|
+
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices);
|
|
6829
7720
|
|
|
6830
|
-
|
|
7721
|
+
if (width == dpitch) {
|
|
7722
|
+
deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys);
|
|
7723
|
+
} else {
|
|
7724
|
+
for (size_t i = 0; i < height; i++) {
|
|
7725
|
+
deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys);
|
|
7726
|
+
}
|
|
7727
|
+
}
|
|
6831
7728
|
return true;
|
|
6832
7729
|
}
|
|
6833
7730
|
|
|
@@ -6835,8 +7732,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t
|
|
|
6835
7732
|
return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
|
|
6836
7733
|
}
|
|
6837
7734
|
|
|
6838
|
-
static void
|
|
6839
|
-
VK_LOG_DEBUG("
|
|
7735
|
+
static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) {
|
|
7736
|
+
VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")");
|
|
6840
7737
|
|
|
6841
7738
|
// If the device is not an UMA device the memory is host-accessible through rebar. While writing
|
|
6842
7739
|
// through PCIe is sufficient fast reading back data from PCIe is slower than going through
|
|
@@ -6844,18 +7741,24 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
|
|
|
6844
7741
|
if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
|
|
6845
7742
|
GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
|
|
6846
7743
|
|
|
6847
|
-
|
|
7744
|
+
if (width == spitch && width == dpitch) {
|
|
7745
|
+
memcpy(dst, (const uint8_t *) src->ptr + offset, width * height);
|
|
7746
|
+
} else {
|
|
7747
|
+
for (size_t i = 0; i < height; i++) {
|
|
7748
|
+
memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width);
|
|
7749
|
+
}
|
|
7750
|
+
}
|
|
6848
7751
|
} else {
|
|
6849
7752
|
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
6850
7753
|
|
|
6851
7754
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
6852
7755
|
ggml_vk_ctx_begin(src->device, subctx);
|
|
6853
|
-
bool ret =
|
|
7756
|
+
bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true);
|
|
6854
7757
|
GGML_ASSERT(ret);
|
|
6855
7758
|
ggml_vk_ctx_end(subctx);
|
|
6856
7759
|
|
|
6857
7760
|
ggml_vk_submit(subctx, src->device->fence);
|
|
6858
|
-
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "
|
|
7761
|
+
VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences");
|
|
6859
7762
|
src->device->device.resetFences({ src->device->fence });
|
|
6860
7763
|
ggml_vk_queue_command_pools_cleanup(src->device);
|
|
6861
7764
|
|
|
@@ -6865,6 +7768,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
|
|
|
6865
7768
|
}
|
|
6866
7769
|
}
|
|
6867
7770
|
|
|
7771
|
+
static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
|
|
7772
|
+
VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
|
|
7773
|
+
ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1);
|
|
7774
|
+
}
|
|
7775
|
+
|
|
6868
7776
|
static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
|
|
6869
7777
|
VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
|
|
6870
7778
|
// Make sure both buffers are on same device
|
|
@@ -6896,7 +7804,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
|
|
|
6896
7804
|
// Copy to src staging buffer
|
|
6897
7805
|
ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
|
|
6898
7806
|
// Copy to dst buffer
|
|
6899
|
-
|
|
7807
|
+
ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size);
|
|
6900
7808
|
}
|
|
6901
7809
|
}
|
|
6902
7810
|
|
|
@@ -6979,6 +7887,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
|
|
|
6979
7887
|
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
|
|
6980
7888
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
6981
7889
|
|
|
7890
|
+
// The q8_1 (integer dot) mmq path uses a different shader with its own
|
|
7891
|
+
// shared-memory layout, so use the int-specific availability flags.
|
|
7892
|
+
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
|
|
7893
|
+
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type];
|
|
7894
|
+
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type];
|
|
7895
|
+
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type];
|
|
7896
|
+
|
|
6982
7897
|
if (ctx->device->coopmat2) {
|
|
6983
7898
|
const uint32_t shader_core_count = ctx->device->shader_core_count;
|
|
6984
7899
|
const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
|
|
@@ -6995,26 +7910,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
|
|
|
6995
7910
|
// split_k==3 with large tiles likely better than medium tiles with no split_k.
|
|
6996
7911
|
(tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
|
|
6997
7912
|
|
|
6998
|
-
if ((
|
|
7913
|
+
if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) {
|
|
6999
7914
|
return aligned ? mmp->a_l : mmp->l;
|
|
7000
7915
|
}
|
|
7001
7916
|
// Use medium shader when the N dimension is greater than the small shader's tile size
|
|
7002
7917
|
uint32_t crossover_medium = mmp->s->wg_denoms[1];
|
|
7003
|
-
if ((
|
|
7918
|
+
if ((mm_m && (n > crossover_medium)) || !mm_s) {
|
|
7004
7919
|
return aligned ? mmp->a_m : mmp->m;
|
|
7005
7920
|
}
|
|
7006
7921
|
return aligned ? mmp->a_s : mmp->s;
|
|
7007
7922
|
}
|
|
7008
7923
|
|
|
7009
|
-
if ((
|
|
7924
|
+
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
|
|
7010
7925
|
return aligned ? mmp->a_s : mmp->s;
|
|
7011
7926
|
}
|
|
7012
|
-
if ((
|
|
7927
|
+
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
|
|
7013
7928
|
return aligned ? mmp->a_m : mmp->m;
|
|
7014
7929
|
}
|
|
7015
7930
|
return aligned ? mmp->a_l : mmp->l;
|
|
7016
|
-
|
|
7017
|
-
GGML_UNUSED(src1_type);
|
|
7018
7931
|
}
|
|
7019
7932
|
|
|
7020
7933
|
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
|
@@ -7071,35 +7984,42 @@ static void ggml_vk_matmul(
|
|
|
7071
7984
|
ctx->prealloc_split_k_need_sync = true;
|
|
7072
7985
|
}
|
|
7073
7986
|
|
|
7074
|
-
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
|
|
7075
|
-
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
|
7987
|
+
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
|
|
7988
|
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
7989
|
+
|
|
7990
|
+
// The q8_1 (integer dot) mmq path uses a different shader with its own
|
|
7991
|
+
// shared-memory layout, so use the int-specific availability flags.
|
|
7992
|
+
const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
|
|
7993
|
+
const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type];
|
|
7994
|
+
const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type];
|
|
7995
|
+
const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type];
|
|
7076
7996
|
|
|
7077
7997
|
if (ctx->device->coopmat2) {
|
|
7078
7998
|
// Use large shader when the N dimension is greater than the medium shader's tile size
|
|
7079
7999
|
uint32_t crossover_large = mmp->m->wg_denoms[1];
|
|
7080
|
-
if ((
|
|
8000
|
+
if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
|
|
7081
8001
|
return aligned ? mmp->a_l : mmp->l;
|
|
7082
8002
|
}
|
|
7083
8003
|
// Use medium shader when the N dimension is greater than the small shader's tile size
|
|
7084
8004
|
uint32_t crossover_medium = mmp->s->wg_denoms[1];
|
|
7085
|
-
if ((
|
|
8005
|
+
if ((mm_m && (n > crossover_medium)) || !mm_s) {
|
|
7086
8006
|
return aligned ? mmp->a_m : mmp->m;
|
|
7087
8007
|
}
|
|
7088
8008
|
return aligned ? mmp->a_s : mmp->s;
|
|
7089
8009
|
}
|
|
7090
8010
|
|
|
7091
|
-
if ((
|
|
8011
|
+
if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
|
|
7092
8012
|
return aligned ? mmp->a_s : mmp->s;
|
|
7093
8013
|
}
|
|
7094
|
-
if ((
|
|
8014
|
+
if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
|
|
7095
8015
|
return aligned ? mmp->a_m : mmp->m;
|
|
7096
8016
|
}
|
|
7097
8017
|
return aligned ? mmp->a_l : mmp->l;
|
|
7098
8018
|
}
|
|
7099
8019
|
|
|
7100
|
-
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
|
|
7101
|
-
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
|
|
7102
|
-
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
|
|
8020
|
+
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
|
8021
|
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
8022
|
+
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
|
|
7103
8023
|
}
|
|
7104
8024
|
|
|
7105
8025
|
static void ggml_vk_matmul_id(
|
|
@@ -7176,6 +8096,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
7176
8096
|
return ctx->device->pipeline_cpy_f32_bf16;
|
|
7177
8097
|
}
|
|
7178
8098
|
}
|
|
8099
|
+
if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) {
|
|
8100
|
+
if (contig) {
|
|
8101
|
+
return ctx->device->pipeline_contig_cpy_bf16_f32;
|
|
8102
|
+
} else {
|
|
8103
|
+
return ctx->device->pipeline_cpy_bf16_f32;
|
|
8104
|
+
}
|
|
8105
|
+
}
|
|
7179
8106
|
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
|
|
7180
8107
|
if (contig) {
|
|
7181
8108
|
return ctx->device->pipeline_contig_cpy_f32_i32;
|
|
@@ -7192,6 +8119,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
7192
8119
|
}
|
|
7193
8120
|
if (src->type == GGML_TYPE_F32) {
|
|
7194
8121
|
switch (to) {
|
|
8122
|
+
case GGML_TYPE_Q1_0:
|
|
7195
8123
|
case GGML_TYPE_Q4_0:
|
|
7196
8124
|
case GGML_TYPE_Q4_1:
|
|
7197
8125
|
case GGML_TYPE_Q5_0:
|
|
@@ -7206,6 +8134,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
7206
8134
|
|
|
7207
8135
|
if (to == GGML_TYPE_F32) {
|
|
7208
8136
|
switch (src->type) {
|
|
8137
|
+
case GGML_TYPE_Q1_0:
|
|
7209
8138
|
case GGML_TYPE_Q4_0:
|
|
7210
8139
|
case GGML_TYPE_Q4_1:
|
|
7211
8140
|
case GGML_TYPE_Q5_0:
|
|
@@ -7272,6 +8201,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
|
|
7272
8201
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
7273
8202
|
}
|
|
7274
8203
|
|
|
8204
|
+
// Copy/convert tensor into a caller-defined dense layout. Destination strides
|
|
8205
|
+
// are in output elements, not bytes.
|
|
8206
|
+
static void ggml_vk_cpy_to_strided(
|
|
8207
|
+
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor,
|
|
8208
|
+
const vk_subbuffer & in, const vk_subbuffer & out,
|
|
8209
|
+
uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) {
|
|
8210
|
+
VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
|
|
8211
|
+
std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
|
|
8212
|
+
const int tensor_type_size = ggml_type_size(tensor->type);
|
|
8213
|
+
|
|
8214
|
+
const uint32_t ne = ggml_nelements(tensor);
|
|
8215
|
+
std::array<uint32_t, 3> elements;
|
|
8216
|
+
|
|
8217
|
+
if (ne > 262144) {
|
|
8218
|
+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
8219
|
+
} else if (ne > 512) {
|
|
8220
|
+
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
|
8221
|
+
} else {
|
|
8222
|
+
elements = { ne, 1, 1 };
|
|
8223
|
+
}
|
|
8224
|
+
|
|
8225
|
+
vk_op_unary_push_constants pc = {
|
|
8226
|
+
(uint32_t)ne,
|
|
8227
|
+
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
|
|
8228
|
+
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13,
|
|
8229
|
+
0,
|
|
8230
|
+
0.0f, 0.0f,
|
|
8231
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
8232
|
+
};
|
|
8233
|
+
init_pushconst_fastdiv(pc);
|
|
8234
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
|
|
8235
|
+
ggml_vk_sync_buffers(ctx, subctx);
|
|
8236
|
+
}
|
|
8237
|
+
|
|
7275
8238
|
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
|
|
7276
8239
|
switch(type) {
|
|
7277
8240
|
case GGML_TYPE_Q8_1:
|
|
@@ -7393,10 +8356,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
7393
8356
|
// Not implemented
|
|
7394
8357
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
7395
8358
|
|
|
7396
|
-
const
|
|
8359
|
+
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
|
|
8360
|
+
|
|
8361
|
+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
|
|
7397
8362
|
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
|
7398
8363
|
|
|
7399
|
-
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type,
|
|
8364
|
+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
|
|
7400
8365
|
|
|
7401
8366
|
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
|
|
7402
8367
|
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
|
|
@@ -7527,24 +8492,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
7527
8492
|
}
|
|
7528
8493
|
if (y_non_contig) {
|
|
7529
8494
|
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
|
7530
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
8495
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
8496
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
7531
8497
|
if (ctx->prealloc_y_need_sync) {
|
|
7532
8498
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
7533
8499
|
}
|
|
7534
8500
|
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
|
7535
8501
|
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
|
7536
8502
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
8503
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
7537
8504
|
}
|
|
7538
8505
|
}
|
|
7539
8506
|
if (quantize_y) {
|
|
7540
8507
|
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
|
7541
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
8508
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
8509
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
7542
8510
|
if (ctx->prealloc_y_need_sync) {
|
|
7543
8511
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
7544
8512
|
}
|
|
7545
8513
|
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
|
|
7546
8514
|
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
|
7547
8515
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
8516
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
7548
8517
|
}
|
|
7549
8518
|
}
|
|
7550
8519
|
|
|
@@ -7585,8 +8554,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|
|
7585
8554
|
return false;
|
|
7586
8555
|
}
|
|
7587
8556
|
|
|
7588
|
-
//
|
|
7589
|
-
|
|
8557
|
+
// q6_k only has 2-byte alignment which makes it somewhat problematic,
|
|
8558
|
+
// using MMVQ is only a win on Intel.
|
|
8559
|
+
bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL;
|
|
8560
|
+
if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) {
|
|
7590
8561
|
return false;
|
|
7591
8562
|
}
|
|
7592
8563
|
|
|
@@ -7598,7 +8569,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|
|
7598
8569
|
// Quantization overhead is not worth it for small k
|
|
7599
8570
|
switch (device->vendor_id) {
|
|
7600
8571
|
case VK_VENDOR_ID_NVIDIA:
|
|
7601
|
-
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
|
|
8572
|
+
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
|
|
7602
8573
|
return true;
|
|
7603
8574
|
}
|
|
7604
8575
|
|
|
@@ -7625,20 +8596,21 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|
|
7625
8596
|
return true;
|
|
7626
8597
|
}
|
|
7627
8598
|
case VK_VENDOR_ID_INTEL:
|
|
7628
|
-
if (
|
|
7629
|
-
|
|
8599
|
+
if (device->architecture == vk_device_architecture::INTEL_XE2) {
|
|
8600
|
+
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
|
|
8601
|
+
return true;
|
|
8602
|
+
}
|
|
7630
8603
|
}
|
|
7631
8604
|
|
|
7632
8605
|
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
|
7633
|
-
// Intel Windows proprietary driver
|
|
7634
|
-
|
|
7635
|
-
|
|
7636
|
-
|
|
7637
|
-
|
|
7638
|
-
|
|
7639
|
-
|
|
7640
|
-
|
|
7641
|
-
}
|
|
8606
|
+
// Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16,
|
|
8607
|
+
// see https://github.com/ggml-org/llama.cpp/issues/17628 and
|
|
8608
|
+
// https://github.com/ggml-org/llama.cpp/pull/23056
|
|
8609
|
+
return false;
|
|
8610
|
+
}
|
|
8611
|
+
|
|
8612
|
+
if (k < 2048) {
|
|
8613
|
+
return false;
|
|
7642
8614
|
}
|
|
7643
8615
|
|
|
7644
8616
|
switch (src0_type) {
|
|
@@ -7799,24 +8771,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
7799
8771
|
if (y_non_contig) {
|
|
7800
8772
|
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
|
7801
8773
|
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
|
7802
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
8774
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
8775
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
7803
8776
|
if (ctx->prealloc_y_need_sync) {
|
|
7804
8777
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
7805
8778
|
}
|
|
7806
8779
|
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
|
|
7807
8780
|
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
|
7808
8781
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
8782
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
7809
8783
|
}
|
|
7810
8784
|
}
|
|
7811
8785
|
if (quantize_y) {
|
|
7812
8786
|
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
|
7813
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
8787
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
8788
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
7814
8789
|
if (ctx->prealloc_y_need_sync) {
|
|
7815
8790
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
7816
8791
|
}
|
|
7817
8792
|
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
|
|
7818
8793
|
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
|
7819
8794
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
8795
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
7820
8796
|
}
|
|
7821
8797
|
}
|
|
7822
8798
|
|
|
@@ -8060,25 +9036,87 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|
|
8060
9036
|
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
|
|
8061
9037
|
}
|
|
8062
9038
|
|
|
8063
|
-
// compute
|
|
8064
|
-
vk_mat_vec_nc_push_constants pc = {
|
|
8065
|
-
(uint32_t)ne00, (uint32_t)ne01,
|
|
8066
|
-
row_stride_x, channel_stride_x, channel_stride_y,
|
|
8067
|
-
(uint32_t)(ne12 / ne02), (uint32_t)ne12,
|
|
8068
|
-
0, 0,
|
|
8069
|
-
nb03, nb13, nb23, fusion_flags
|
|
8070
|
-
};
|
|
9039
|
+
// compute
|
|
9040
|
+
vk_mat_vec_nc_push_constants pc = {
|
|
9041
|
+
(uint32_t)ne00, (uint32_t)ne01,
|
|
9042
|
+
row_stride_x, channel_stride_x, channel_stride_y,
|
|
9043
|
+
(uint32_t)(ne12 / ne02), (uint32_t)ne12,
|
|
9044
|
+
0, 0,
|
|
9045
|
+
nb03, nb13, nb23, fusion_flags
|
|
9046
|
+
};
|
|
9047
|
+
|
|
9048
|
+
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
|
9049
|
+
|
|
9050
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
9051
|
+
{
|
|
9052
|
+
d_Qx,
|
|
9053
|
+
d_Qy,
|
|
9054
|
+
d_D,
|
|
9055
|
+
d_F0,
|
|
9056
|
+
d_F1,
|
|
9057
|
+
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
|
|
9058
|
+
}
|
|
9059
|
+
|
|
9060
|
+
static int ggml_vk_fwht_pipeline_idx(int64_t n) {
|
|
9061
|
+
switch (n) {
|
|
9062
|
+
case 64: return 0;
|
|
9063
|
+
case 128: return 1;
|
|
9064
|
+
case 256: return 2;
|
|
9065
|
+
case 512: return 3;
|
|
9066
|
+
default: return -1;
|
|
9067
|
+
}
|
|
9068
|
+
}
|
|
9069
|
+
|
|
9070
|
+
static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) {
|
|
9071
|
+
if (ctx->num_additional_fused_ops != 0) {
|
|
9072
|
+
return false;
|
|
9073
|
+
}
|
|
9074
|
+
|
|
9075
|
+
if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) {
|
|
9076
|
+
return false;
|
|
9077
|
+
}
|
|
9078
|
+
|
|
9079
|
+
const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]);
|
|
9080
|
+
if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) {
|
|
9081
|
+
return false;
|
|
9082
|
+
}
|
|
9083
|
+
|
|
9084
|
+
if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
|
9085
|
+
return false;
|
|
9086
|
+
}
|
|
9087
|
+
|
|
9088
|
+
if (!ggml_is_contiguous(src1)) {
|
|
9089
|
+
return false;
|
|
9090
|
+
}
|
|
9091
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
9092
|
+
|
|
9093
|
+
return true;
|
|
9094
|
+
}
|
|
9095
|
+
|
|
9096
|
+
static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) {
|
|
9097
|
+
const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]);
|
|
9098
|
+
vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx];
|
|
9099
|
+
|
|
9100
|
+
const uint32_t rows_per_workgroup = 4;
|
|
9101
|
+
const uint32_t n_rows = (uint32_t)ggml_nrows(src);
|
|
9102
|
+
const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
|
9103
|
+
|
|
9104
|
+
const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup);
|
|
9105
|
+
const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x);
|
|
9106
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
9107
|
+
|
|
9108
|
+
const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true);
|
|
9109
|
+
const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
|
|
8071
9110
|
|
|
8072
|
-
|
|
9111
|
+
vk_op_fwht_push_constants pc = {
|
|
9112
|
+
n_rows,
|
|
9113
|
+
0,
|
|
9114
|
+
0,
|
|
9115
|
+
1.0f / std::sqrt((float)src->ne[0]),
|
|
9116
|
+
};
|
|
9117
|
+
init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst);
|
|
8073
9118
|
|
|
8074
|
-
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
8075
|
-
{
|
|
8076
|
-
d_Qx,
|
|
8077
|
-
d_Qy,
|
|
8078
|
-
d_D,
|
|
8079
|
-
d_F0,
|
|
8080
|
-
d_F1,
|
|
8081
|
-
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
|
|
9119
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 });
|
|
8082
9120
|
}
|
|
8083
9121
|
|
|
8084
9122
|
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
@@ -8114,6 +9152,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
8114
9152
|
|
|
8115
9153
|
m_offset += cur_M_size;
|
|
8116
9154
|
}
|
|
9155
|
+
} else if (ggml_vk_can_use_fwht(ctx, src1, dst)) {
|
|
9156
|
+
ggml_vk_fwht(ctx, subctx, src1, dst);
|
|
8117
9157
|
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
|
8118
9158
|
// detect 0213 permutation, and batch size of 1
|
|
8119
9159
|
src0->nb[0] <= src0->nb[2] &&
|
|
@@ -8203,12 +9243,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8203
9243
|
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
|
|
8204
9244
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
|
8205
9245
|
!ggml_vk_dim01_contiguous(src0);
|
|
8206
|
-
|
|
9246
|
+
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
|
9247
|
+
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
|
9248
|
+
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
|
9249
|
+
// B must already be, or be convertible to, the matmul B type used by this path.
|
|
9250
|
+
const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector &&
|
|
9251
|
+
(f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) &&
|
|
9252
|
+
(src1->type == GGML_TYPE_F32 || src1->type == f16_type);
|
|
9253
|
+
// If B is copied to prealloc_y, we can choose a 4-element-aligned row stride.
|
|
9254
|
+
const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type;
|
|
9255
|
+
// Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned.
|
|
9256
|
+
const bool y_decode_vector_aligned =
|
|
9257
|
+
(ne10 % 4 == 0) &&
|
|
9258
|
+
(y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0);
|
|
9259
|
+
// Stage B only when decode-vector is available and direct B reads would be misaligned.
|
|
9260
|
+
const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned;
|
|
9261
|
+
#else
|
|
9262
|
+
const bool y_decode_vector_staging = false;
|
|
9263
|
+
#endif
|
|
9264
|
+
const bool y_non_contig = y_decode_vector_staging ||
|
|
9265
|
+
(ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
|
8207
9266
|
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
|
8208
9267
|
!ggml_vk_dim01_contiguous(src1);
|
|
8209
9268
|
|
|
8210
|
-
|
|
8211
|
-
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
|
9269
|
+
const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10;
|
|
8212
9270
|
|
|
8213
9271
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
8214
9272
|
|
|
@@ -8234,10 +9292,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8234
9292
|
// Not implemented
|
|
8235
9293
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
8236
9294
|
|
|
8237
|
-
const
|
|
9295
|
+
const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
|
|
9296
|
+
|
|
9297
|
+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
|
|
8238
9298
|
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
|
|
8239
9299
|
|
|
8240
|
-
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
|
9300
|
+
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
|
|
8241
9301
|
|
|
8242
9302
|
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
|
|
8243
9303
|
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
|
|
@@ -8245,11 +9305,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8245
9305
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
|
8246
9306
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
|
8247
9307
|
const uint64_t x_ne = ggml_nelements(src0);
|
|
8248
|
-
const uint64_t y_ne =
|
|
9308
|
+
const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13;
|
|
8249
9309
|
const uint64_t d_ne = ggml_nelements(dst);
|
|
8250
9310
|
|
|
8251
9311
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
8252
|
-
const uint64_t qy_sz = ggml_type_size(src1->type) *
|
|
9312
|
+
const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type);
|
|
8253
9313
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
|
8254
9314
|
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
|
8255
9315
|
const uint64_t ids_sz = nbi2;
|
|
@@ -8259,13 +9319,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8259
9319
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
|
8260
9320
|
vk_pipeline to_q8_1 = nullptr;
|
|
8261
9321
|
|
|
9322
|
+
auto make_y_staged_dst = [&]() {
|
|
9323
|
+
ggml_tensor y_staged_dst = *src1;
|
|
9324
|
+
y_staged_dst.type = f16_type;
|
|
9325
|
+
y_staged_dst.nb[0] = ggml_type_size(f16_type);
|
|
9326
|
+
y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride;
|
|
9327
|
+
y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n;
|
|
9328
|
+
y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2];
|
|
9329
|
+
return y_staged_dst;
|
|
9330
|
+
};
|
|
9331
|
+
|
|
8262
9332
|
if (x_non_contig) {
|
|
8263
9333
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
|
8264
9334
|
} else {
|
|
8265
9335
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
|
8266
9336
|
}
|
|
8267
9337
|
if (y_non_contig) {
|
|
8268
|
-
|
|
9338
|
+
ggml_tensor y_staged_dst;
|
|
9339
|
+
const ggml_tensor * y_staged_dst_ptr = nullptr;
|
|
9340
|
+
if (y_decode_vector_staging) {
|
|
9341
|
+
y_staged_dst = make_y_staged_dst();
|
|
9342
|
+
y_staged_dst_ptr = &y_staged_dst;
|
|
9343
|
+
}
|
|
9344
|
+
|
|
9345
|
+
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type);
|
|
8269
9346
|
} else {
|
|
8270
9347
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
|
8271
9348
|
}
|
|
@@ -8383,30 +9460,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8383
9460
|
}
|
|
8384
9461
|
if (y_non_contig) {
|
|
8385
9462
|
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
|
8386
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
9463
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
9464
|
+
ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) {
|
|
8387
9465
|
if (ctx->prealloc_y_need_sync) {
|
|
8388
9466
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8389
9467
|
}
|
|
8390
|
-
|
|
9468
|
+
if (y_decode_vector_staging) {
|
|
9469
|
+
const ggml_tensor y_staged_dst = make_y_staged_dst();
|
|
9470
|
+
const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type);
|
|
9471
|
+
ggml_vk_cpy_to_strided(
|
|
9472
|
+
ctx, subctx, to_fp16_vk_1, src1,
|
|
9473
|
+
ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0),
|
|
9474
|
+
(uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size),
|
|
9475
|
+
(uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size),
|
|
9476
|
+
(uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size),
|
|
9477
|
+
(uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size));
|
|
9478
|
+
} else {
|
|
9479
|
+
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
|
9480
|
+
}
|
|
8391
9481
|
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
|
8392
9482
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
9483
|
+
ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging;
|
|
8393
9484
|
}
|
|
8394
9485
|
}
|
|
8395
9486
|
if (quantize_y) {
|
|
8396
9487
|
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
|
8397
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
9488
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
9489
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
8398
9490
|
if (ctx->prealloc_y_need_sync) {
|
|
8399
9491
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8400
9492
|
}
|
|
8401
9493
|
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
|
|
8402
9494
|
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
|
8403
9495
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
9496
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
8404
9497
|
}
|
|
8405
9498
|
}
|
|
8406
9499
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8407
9500
|
|
|
8408
9501
|
uint32_t stride_batch_x = ne00*ne01;
|
|
8409
|
-
uint32_t
|
|
9502
|
+
uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10;
|
|
9503
|
+
uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11;
|
|
8410
9504
|
|
|
8411
9505
|
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
|
|
8412
9506
|
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
|
@@ -8421,7 +9515,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
8421
9515
|
ctx, subctx, pipeline,
|
|
8422
9516
|
{ d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
|
|
8423
9517
|
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
|
|
8424
|
-
ne01, ne21, ne10, ne10,
|
|
9518
|
+
ne01, ne21, ne10, ne10, stride_b_y, ne01,
|
|
8425
9519
|
stride_batch_x, stride_batch_y, ne20*ne21,
|
|
8426
9520
|
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
|
|
8427
9521
|
); // NOLINT
|
|
@@ -8579,24 +9673,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
8579
9673
|
if (y_non_contig) {
|
|
8580
9674
|
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
|
8581
9675
|
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
|
|
8582
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
9676
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
9677
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
8583
9678
|
if (ctx->prealloc_y_need_sync) {
|
|
8584
9679
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8585
9680
|
}
|
|
8586
9681
|
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
|
|
8587
9682
|
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
|
8588
9683
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
9684
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
8589
9685
|
}
|
|
8590
9686
|
}
|
|
8591
9687
|
if (quantize_y) {
|
|
8592
9688
|
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
|
8593
|
-
ctx->prealloc_y_last_tensor_used != src1
|
|
9689
|
+
ctx->prealloc_y_last_tensor_used != src1 ||
|
|
9690
|
+
ctx->prealloc_y_last_decode_vector_staging) {
|
|
8594
9691
|
if (ctx->prealloc_y_need_sync) {
|
|
8595
9692
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8596
9693
|
}
|
|
8597
9694
|
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
|
|
8598
9695
|
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
|
8599
9696
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
9697
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
8600
9698
|
}
|
|
8601
9699
|
}
|
|
8602
9700
|
|
|
@@ -8687,14 +9785,18 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8687
9785
|
}
|
|
8688
9786
|
}
|
|
8689
9787
|
|
|
8690
|
-
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
9788
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
|
|
8691
9789
|
GGML_UNUSED(f32acc);
|
|
9790
|
+
GGML_UNUSED(v_type);
|
|
8692
9791
|
// Needs to be kept up to date on shader changes
|
|
8693
9792
|
const uint32_t wg_size = params.workgroup_size;
|
|
8694
9793
|
const uint32_t Br = params.block_rows;
|
|
8695
9794
|
const uint32_t Bc = params.block_cols;
|
|
8696
9795
|
|
|
8697
|
-
|
|
9796
|
+
// BF16 uses the fp32 shader (FLOAT_TYPE=float)
|
|
9797
|
+
const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float);
|
|
9798
|
+
|
|
9799
|
+
const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
|
|
8698
9800
|
|
|
8699
9801
|
// tmpsh is overestimated slightly
|
|
8700
9802
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
@@ -8702,20 +9804,38 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
|
|
8702
9804
|
|
|
8703
9805
|
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
|
8704
9806
|
|
|
8705
|
-
|
|
9807
|
+
uint32_t Qf, kvsh, kblocksh_size;
|
|
9808
|
+
if (mmq) {
|
|
9809
|
+
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
|
|
9810
|
+
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
|
|
9811
|
+
Qf = Br * (hsk / 32) * block_b_size;
|
|
9812
|
+
|
|
9813
|
+
// kvsh uses D = HSV (K goes through kblocksh instead)
|
|
9814
|
+
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
|
9815
|
+
|
|
9816
|
+
// The mixed MMQ shader uses a superset block_a_cache that fits every
|
|
9817
|
+
// FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
|
|
9818
|
+
// Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
|
|
9819
|
+
const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
|
|
9820
|
+
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
|
|
9821
|
+
} else {
|
|
9822
|
+
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
|
9823
|
+
|
|
9824
|
+
const uint32_t D = std::max(hsk, hsv);
|
|
9825
|
+
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
|
8706
9826
|
|
|
8707
|
-
|
|
8708
|
-
|
|
9827
|
+
kblocksh_size = 0;
|
|
9828
|
+
}
|
|
8709
9829
|
|
|
8710
|
-
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
|
9830
|
+
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
|
|
8711
9831
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
8712
9832
|
|
|
8713
|
-
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
|
9833
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
|
|
8714
9834
|
|
|
8715
9835
|
return supported;
|
|
8716
9836
|
}
|
|
8717
9837
|
|
|
8718
|
-
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
9838
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) {
|
|
8719
9839
|
// Needs to be kept up to date on shader changes
|
|
8720
9840
|
const uint32_t Br = params.block_rows;
|
|
8721
9841
|
const uint32_t Bc = params.block_cols;
|
|
@@ -8745,8 +9865,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
8745
9865
|
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
|
8746
9866
|
const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
|
|
8747
9867
|
|
|
9868
|
+
// BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes)
|
|
9869
|
+
const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4;
|
|
8748
9870
|
const uint32_t osh_stride = params.row_split * MatBr / 4;
|
|
8749
|
-
const uint32_t pvsh = MatBc * osh_stride *
|
|
9871
|
+
const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size;
|
|
8750
9872
|
|
|
8751
9873
|
const uint32_t slope = Br * acctype;
|
|
8752
9874
|
|
|
@@ -8809,19 +9931,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8809
9931
|
|
|
8810
9932
|
assert(dst->type == GGML_TYPE_F32);
|
|
8811
9933
|
assert(q->type == GGML_TYPE_F32);
|
|
8812
|
-
assert(k->type == v->type);
|
|
8813
|
-
|
|
8814
9934
|
uint32_t gqa_ratio = 1;
|
|
8815
9935
|
uint32_t qk_ratio = neq2 / nek2;
|
|
8816
9936
|
uint32_t workgroups_x = (uint32_t)neq1;
|
|
8817
9937
|
uint32_t workgroups_y = (uint32_t)neq2;
|
|
8818
9938
|
uint32_t workgroups_z = (uint32_t)neq3;
|
|
8819
9939
|
|
|
8820
|
-
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
|
|
9940
|
+
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16;
|
|
8821
9941
|
|
|
8822
9942
|
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
|
8823
9943
|
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
|
8824
|
-
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
|
|
9944
|
+
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
|
|
8825
9945
|
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
|
|
8826
9946
|
|
|
8827
9947
|
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
@@ -8834,7 +9954,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8834
9954
|
workgroups_y /= gqa_ratio;
|
|
8835
9955
|
}
|
|
8836
9956
|
|
|
8837
|
-
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
|
|
9957
|
+
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
|
|
8838
9958
|
|
|
8839
9959
|
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
|
8840
9960
|
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
|
@@ -8873,13 +9993,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8873
9993
|
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
|
8874
9994
|
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
|
8875
9995
|
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
|
8876
|
-
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
|
9996
|
+
mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
|
|
8877
9997
|
|
|
8878
9998
|
vk_pipeline pipeline = nullptr;
|
|
8879
9999
|
|
|
8880
10000
|
{
|
|
8881
|
-
std::lock_guard<std::
|
|
8882
|
-
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16
|
|
10001
|
+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
|
10002
|
+
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
|
|
8883
10003
|
auto it = pipelines.find(fa_pipeline_state);
|
|
8884
10004
|
if (it != pipelines.end()) {
|
|
8885
10005
|
pipeline = it->second;
|
|
@@ -8942,13 +10062,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8942
10062
|
|
|
8943
10063
|
vk_pipeline pipeline_fa_mask_opt = nullptr;
|
|
8944
10064
|
if (use_mask_opt) {
|
|
8945
|
-
|
|
8946
|
-
|
|
8947
|
-
|
|
8948
|
-
|
|
8949
|
-
|
|
8950
|
-
|
|
8951
|
-
|
|
10065
|
+
{
|
|
10066
|
+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
|
10067
|
+
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
|
|
10068
|
+
auto it = pipelines.find({Br, Bc});
|
|
10069
|
+
if (it != pipelines.end()) {
|
|
10070
|
+
pipeline_fa_mask_opt = it->second;
|
|
10071
|
+
} else {
|
|
10072
|
+
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
|
|
10073
|
+
}
|
|
8952
10074
|
}
|
|
8953
10075
|
assert(pipeline_fa_mask_opt);
|
|
8954
10076
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
|
|
@@ -9059,10 +10181,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u
|
|
|
9059
10181
|
// so small convolutions will still choose a smaller tile.
|
|
9060
10182
|
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
|
|
9061
10183
|
|
|
9062
|
-
|
|
10184
|
+
// 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile.
|
|
10185
|
+
bool allow_128x128 = true;
|
|
10186
|
+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
10187
|
+
if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) {
|
|
10188
|
+
allow_128x128 = false;
|
|
10189
|
+
}
|
|
10190
|
+
#endif
|
|
10191
|
+
|
|
10192
|
+
if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
|
|
9063
10193
|
return CONV_SHAPE_128x128;
|
|
9064
10194
|
} else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
|
|
9065
10195
|
return CONV_SHAPE_32x256;
|
|
10196
|
+
} else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
|
|
10197
|
+
return CONV_SHAPE_64x128;
|
|
10198
|
+
} else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
|
|
10199
|
+
// cm1 fallback for large K when 128x128 isn't available
|
|
10200
|
+
return CONV_SHAPE_64x128;
|
|
9066
10201
|
} else {
|
|
9067
10202
|
return CONV_SHAPE_64x32;
|
|
9068
10203
|
}
|
|
@@ -9234,7 +10369,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9234
10369
|
return nullptr;
|
|
9235
10370
|
case GGML_OP_REPEAT:
|
|
9236
10371
|
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
|
|
9237
|
-
return ctx->device->
|
|
10372
|
+
return ctx->device->pipeline_repeat_i32;
|
|
10373
|
+
}
|
|
10374
|
+
if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) {
|
|
10375
|
+
return ctx->device->pipeline_repeat_i16;
|
|
9238
10376
|
}
|
|
9239
10377
|
return nullptr;
|
|
9240
10378
|
case GGML_OP_REPEAT_BACK:
|
|
@@ -9466,7 +10604,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9466
10604
|
vk_pipeline pipeline = nullptr;
|
|
9467
10605
|
|
|
9468
10606
|
{
|
|
9469
|
-
std::lock_guard<std::
|
|
10607
|
+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
|
9470
10608
|
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
|
|
9471
10609
|
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
|
|
9472
10610
|
pipeline = it->second;
|
|
@@ -9555,7 +10693,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9555
10693
|
return nullptr;
|
|
9556
10694
|
case GGML_OP_SSM_CONV:
|
|
9557
10695
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
9558
|
-
|
|
10696
|
+
switch (ctx->num_additional_fused_ops) {
|
|
10697
|
+
case 0: return ctx->device->pipeline_ssm_conv_f32;
|
|
10698
|
+
case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
|
|
10699
|
+
case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
|
|
10700
|
+
default: return nullptr;
|
|
10701
|
+
}
|
|
9559
10702
|
}
|
|
9560
10703
|
return nullptr;
|
|
9561
10704
|
case GGML_OP_OPT_STEP_ADAMW:
|
|
@@ -9589,7 +10732,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9589
10732
|
uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
|
|
9590
10733
|
uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
|
|
9591
10734
|
uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
|
|
9592
|
-
|
|
10735
|
+
|
|
10736
|
+
// tile-aligned shapes let the shader skip bounds checks
|
|
10737
|
+
const uint32_t Cin = (uint32_t)src1->ne[2];
|
|
10738
|
+
const uint32_t CRS = Cin * KW * KH;
|
|
10739
|
+
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
|
|
10740
|
+
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
|
|
10741
|
+
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
|
|
10742
|
+
const uint32_t aligned = ((K % BS_K == 0) &&
|
|
10743
|
+
(CRS % BS_CRS == 0) &&
|
|
10744
|
+
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
|
|
10745
|
+
|
|
10746
|
+
vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned);
|
|
9593
10747
|
|
|
9594
10748
|
std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
|
9595
10749
|
if (op == GGML_OP_CONV_2D) {
|
|
@@ -9609,7 +10763,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9609
10763
|
vk_pipeline pipeline = nullptr;
|
|
9610
10764
|
|
|
9611
10765
|
{
|
|
9612
|
-
std::lock_guard<std::
|
|
10766
|
+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
|
9613
10767
|
auto it = pipelines->find(conv2d_pipeline_state);
|
|
9614
10768
|
if (it != pipelines->end()) {
|
|
9615
10769
|
pipeline = it->second;
|
|
@@ -9656,6 +10810,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9656
10810
|
if (dst->type == GGML_TYPE_F32) {
|
|
9657
10811
|
return ctx->device->pipeline_fill_f32;
|
|
9658
10812
|
}
|
|
10813
|
+
if (dst->type == GGML_TYPE_F16) {
|
|
10814
|
+
return ctx->device->pipeline_fill_f16;
|
|
10815
|
+
}
|
|
9659
10816
|
return nullptr;
|
|
9660
10817
|
default:
|
|
9661
10818
|
return nullptr;
|
|
@@ -9733,6 +10890,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|
|
9733
10890
|
GGML_UNUSED(src3);
|
|
9734
10891
|
}
|
|
9735
10892
|
|
|
10893
|
+
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
|
|
10894
|
+
p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
|
10895
|
+
p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
|
10896
|
+
|
|
10897
|
+
GGML_UNUSED(src1);
|
|
10898
|
+
GGML_UNUSED(src2);
|
|
10899
|
+
GGML_UNUSED(src3);
|
|
10900
|
+
}
|
|
10901
|
+
|
|
9736
10902
|
template<typename PC>
|
|
9737
10903
|
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
|
|
9738
10904
|
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
|
@@ -9876,7 +11042,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
9876
11042
|
|
|
9877
11043
|
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
|
9878
11044
|
|
|
9879
|
-
|
|
11045
|
+
const uint32_t CHW = IC * KH * KW;
|
|
11046
|
+
// Cap X workgroups to limit concurrent IC channel reads.
|
|
11047
|
+
// The shader loops over X to cover the full CHW dimension.
|
|
11048
|
+
// AMD prefers a lower limit
|
|
11049
|
+
const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
|
|
11050
|
+
const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
|
|
11051
|
+
elements = { x_elements, OW, OH * batch };
|
|
9880
11052
|
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
9881
11053
|
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
9882
11054
|
} break;
|
|
@@ -10385,6 +11557,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|
|
10385
11557
|
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
|
|
10386
11558
|
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
|
|
10387
11559
|
|
|
11560
|
+
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
|
11561
|
+
const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
|
11562
|
+
|
|
10388
11563
|
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
|
|
10389
11564
|
|
|
10390
11565
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
|
@@ -10418,12 +11593,13 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|
|
10418
11593
|
sv1, sv2, sv3,
|
|
10419
11594
|
sb1, sb2, sb3,
|
|
10420
11595
|
neq1, rq3,
|
|
10421
|
-
scale
|
|
11596
|
+
scale,
|
|
11597
|
+
K
|
|
10422
11598
|
};
|
|
10423
11599
|
|
|
10424
11600
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
10425
11601
|
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
|
|
10426
|
-
pc, { H, n_seqs,
|
|
11602
|
+
pc, { H, n_seqs, S_v });
|
|
10427
11603
|
}
|
|
10428
11604
|
|
|
10429
11605
|
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
|
@@ -10482,11 +11658,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
10482
11658
|
pc, elements);
|
|
10483
11659
|
}
|
|
10484
11660
|
|
|
10485
|
-
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
10486
|
-
|
|
10487
|
-
const ggml_tensor *
|
|
11661
|
+
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
11662
|
+
ggml_tensor * conv = cgraph->nodes[node_idx];
|
|
11663
|
+
const ggml_tensor * src0 = conv->src[0];
|
|
11664
|
+
const ggml_tensor * src1 = conv->src[1];
|
|
11665
|
+
|
|
11666
|
+
// Pick the destination tensor (last node in the fused chain) and the optional bias.
|
|
11667
|
+
// Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
|
|
11668
|
+
ggml_tensor * dst = conv;
|
|
11669
|
+
const ggml_tensor * bias = nullptr;
|
|
10488
11670
|
|
|
10489
|
-
|
|
11671
|
+
if (ctx->num_additional_fused_ops == 1) {
|
|
11672
|
+
dst = cgraph->nodes[node_idx + 1]; // silu
|
|
11673
|
+
} else if (ctx->num_additional_fused_ops == 2) {
|
|
11674
|
+
ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
|
11675
|
+
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
|
11676
|
+
dst = cgraph->nodes[node_idx + 2]; // silu
|
|
11677
|
+
}
|
|
11678
|
+
|
|
11679
|
+
// The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
|
|
11680
|
+
const ggml_tensor * src2 = bias ? bias : src0;
|
|
11681
|
+
|
|
11682
|
+
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
|
|
10490
11683
|
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
|
|
10491
11684
|
(uint32_t)src1->nb[1],
|
|
10492
11685
|
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
|
|
@@ -10849,6 +12042,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
|
|
|
10849
12042
|
(uint32_t)src0->ne[2],
|
|
10850
12043
|
nb01, nb02, nb03,
|
|
10851
12044
|
nb11, nb12, nb13,
|
|
12045
|
+
0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets
|
|
10852
12046
|
};
|
|
10853
12047
|
|
|
10854
12048
|
return rope;
|
|
@@ -10944,6 +12138,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
10944
12138
|
GGML_ASSERT(buf[i] != nullptr);
|
|
10945
12139
|
}
|
|
10946
12140
|
|
|
12141
|
+
// a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned.
|
|
12142
|
+
// Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset.
|
|
12143
|
+
pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type);
|
|
12144
|
+
offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1);
|
|
12145
|
+
|
|
10947
12146
|
std::array<uint32_t, 3> elements;
|
|
10948
12147
|
elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
|
|
10949
12148
|
|
|
@@ -11003,8 +12202,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
11003
12202
|
const float alpha = op_params_f[2];
|
|
11004
12203
|
const float limit = op_params_f[3];
|
|
11005
12204
|
|
|
11006
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
11007
|
-
|
|
11008
12205
|
if (!split) {
|
|
11009
12206
|
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
|
11010
12207
|
} else {
|
|
@@ -11022,7 +12219,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
11022
12219
|
(uint32_t)dst->ne[0],
|
|
11023
12220
|
mode,
|
|
11024
12221
|
alpha,
|
|
11025
|
-
limit
|
|
12222
|
+
limit,
|
|
12223
|
+
(uint32_t)(src0->nb[1] / src0->nb[0]),
|
|
12224
|
+
(uint32_t)(src0->nb[2] / src0->nb[0]),
|
|
12225
|
+
(uint32_t)(src0->nb[3] / src0->nb[0]),
|
|
12226
|
+
(uint32_t)src0->ne[1],
|
|
12227
|
+
(uint32_t)src0->ne[2],
|
|
12228
|
+
(uint32_t)(dst->nb[1] / dst->nb[0]),
|
|
12229
|
+
(uint32_t)(dst->nb[2] / dst->nb[0]),
|
|
12230
|
+
(uint32_t)(dst->nb[3] / dst->nb[0]),
|
|
12231
|
+
(uint32_t)dst->ne[1],
|
|
12232
|
+
(uint32_t)dst->ne[2]
|
|
11026
12233
|
});
|
|
11027
12234
|
}
|
|
11028
12235
|
|
|
@@ -11531,7 +12738,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
11531
12738
|
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
|
11532
12739
|
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
|
11533
12740
|
|
|
11534
|
-
const uint32_t pelements = OW * KW * KH;
|
|
11535
12741
|
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
|
11536
12742
|
|
|
11537
12743
|
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
@@ -11543,7 +12749,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
11543
12749
|
dst_addr,
|
|
11544
12750
|
batch_offset, offset_delta,
|
|
11545
12751
|
IC, IW, IH, OW, OH, KW, KH,
|
|
11546
|
-
|
|
12752
|
+
OH * batch,
|
|
11547
12753
|
IC * KH * KW,
|
|
11548
12754
|
s0, s1, p0, p1, d0, d1, batch * IC
|
|
11549
12755
|
});
|
|
@@ -11656,6 +12862,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
|
|
|
11656
12862
|
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
|
|
11657
12863
|
}
|
|
11658
12864
|
|
|
12865
|
+
// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
|
|
12866
|
+
// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
|
|
12867
|
+
// dedicated kernel directly. The pattern is validated by
|
|
12868
|
+
// ggml_vk_can_fuse_snake before this call.
|
|
12869
|
+
static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
|
|
12870
|
+
const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
|
|
12871
|
+
const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
|
|
12872
|
+
const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
|
|
12873
|
+
ggml_tensor * add = cgraph->nodes[node_idx + 4];
|
|
12874
|
+
|
|
12875
|
+
// x carries the full activation shape, a is the broadcast operand
|
|
12876
|
+
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
|
|
12877
|
+
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
|
|
12878
|
+
|
|
12879
|
+
// mul1 reads sqr and inv_b in either operand order
|
|
12880
|
+
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
|
|
12881
|
+
|
|
12882
|
+
vk_pipeline pipeline = nullptr;
|
|
12883
|
+
switch (x->type) {
|
|
12884
|
+
case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break;
|
|
12885
|
+
case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break;
|
|
12886
|
+
case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break;
|
|
12887
|
+
default: GGML_ABORT("unsupported type");
|
|
12888
|
+
}
|
|
12889
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
12890
|
+
|
|
12891
|
+
vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
|
|
12892
|
+
vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a);
|
|
12893
|
+
vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b);
|
|
12894
|
+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add);
|
|
12895
|
+
|
|
12896
|
+
vk_op_snake_push_constants pc{};
|
|
12897
|
+
pc.ne0 = static_cast<uint32_t>(x->ne[0]);
|
|
12898
|
+
pc.ne1 = static_cast<uint32_t>(x->ne[1]);
|
|
12899
|
+
|
|
12900
|
+
std::array<uint32_t, 3> elements = { pc.ne0, pc.ne1, 1 };
|
|
12901
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements);
|
|
12902
|
+
}
|
|
12903
|
+
|
|
11659
12904
|
static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
|
11660
12905
|
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
|
|
11661
12906
|
const int32_t k1 = dst->op_params[1];
|
|
@@ -12673,7 +13918,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
|
|
12673
13918
|
ggml_vk_destroy_buffer(ctx->prealloc_y);
|
|
12674
13919
|
}
|
|
12675
13920
|
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
|
|
13921
|
+
ctx->prealloc_y_last_pipeline_used = nullptr;
|
|
12676
13922
|
ctx->prealloc_y_last_tensor_used = nullptr;
|
|
13923
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
12677
13924
|
}
|
|
12678
13925
|
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
|
|
12679
13926
|
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
|
|
@@ -12801,6 +14048,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12801
14048
|
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
|
|
12802
14049
|
ctx->query_node_idx[ctx->query_idx] = node_idx;
|
|
12803
14050
|
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
14051
|
+
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
12804
14052
|
}
|
|
12805
14053
|
}
|
|
12806
14054
|
// Add all fused nodes to the unsynchronized lists.
|
|
@@ -12863,7 +14111,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12863
14111
|
|
|
12864
14112
|
break;
|
|
12865
14113
|
case GGML_OP_MUL:
|
|
12866
|
-
|
|
14114
|
+
if (ctx->num_additional_fused_ops) {
|
|
14115
|
+
ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx);
|
|
14116
|
+
} else {
|
|
14117
|
+
ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
|
|
14118
|
+
}
|
|
12867
14119
|
|
|
12868
14120
|
break;
|
|
12869
14121
|
case GGML_OP_DIV:
|
|
@@ -13153,7 +14405,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
13153
14405
|
break;
|
|
13154
14406
|
|
|
13155
14407
|
case GGML_OP_SSM_CONV:
|
|
13156
|
-
ggml_vk_ssm_conv(ctx, compute_ctx,
|
|
14408
|
+
ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
|
|
13157
14409
|
|
|
13158
14410
|
break;
|
|
13159
14411
|
|
|
@@ -13248,6 +14500,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
13248
14500
|
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
|
13249
14501
|
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
|
|
13250
14502
|
ctx->prealloc_y_last_pipeline_used = {};
|
|
14503
|
+
ctx->prealloc_y_last_tensor_used = nullptr;
|
|
14504
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
13251
14505
|
|
|
13252
14506
|
ctx->unsynced_nodes_written.clear();
|
|
13253
14507
|
ctx->unsynced_nodes_read.clear();
|
|
@@ -13298,6 +14552,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
13298
14552
|
ggml_vk_destroy_buffer(ctx->sync_staging);
|
|
13299
14553
|
|
|
13300
14554
|
ctx->prealloc_y_last_pipeline_used = nullptr;
|
|
14555
|
+
ctx->prealloc_y_last_tensor_used = nullptr;
|
|
14556
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
13301
14557
|
|
|
13302
14558
|
ctx->prealloc_size_x = 0;
|
|
13303
14559
|
ctx->prealloc_size_y = 0;
|
|
@@ -13401,6 +14657,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
|
|
|
13401
14657
|
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
|
13402
14658
|
}
|
|
13403
14659
|
|
|
14660
|
+
static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset,
|
|
14661
|
+
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
|
|
14662
|
+
VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
|
|
14663
|
+
n_copies << ", " << stride_tensor << ", " << stride_data << ")");
|
|
14664
|
+
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
14665
|
+
vk_buffer buf = buf_ctx->dev_buffer;
|
|
14666
|
+
|
|
14667
|
+
if (size == 0) {
|
|
14668
|
+
return;
|
|
14669
|
+
}
|
|
14670
|
+
|
|
14671
|
+
ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies);
|
|
14672
|
+
}
|
|
14673
|
+
|
|
13404
14674
|
static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
13405
14675
|
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
|
13406
14676
|
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
@@ -13414,6 +14684,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
|
|
|
13414
14684
|
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
|
13415
14685
|
}
|
|
13416
14686
|
|
|
14687
|
+
static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset,
|
|
14688
|
+
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
|
|
14689
|
+
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
|
|
14690
|
+
n_copies << ", " << stride_tensor << ", " << stride_data << ")");
|
|
14691
|
+
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
14692
|
+
|
|
14693
|
+
if (size == 0) {
|
|
14694
|
+
return;
|
|
14695
|
+
}
|
|
14696
|
+
|
|
14697
|
+
vk_buffer buf = buf_ctx->dev_buffer;
|
|
14698
|
+
|
|
14699
|
+
ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies);
|
|
14700
|
+
}
|
|
14701
|
+
|
|
13417
14702
|
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
13418
14703
|
if (ggml_nbytes(src) == 0) {
|
|
13419
14704
|
return true;
|
|
@@ -13448,6 +14733,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
|
|
|
13448
14733
|
/* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
|
|
13449
14734
|
/* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
|
|
13450
14735
|
/* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
|
|
14736
|
+
/* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d,
|
|
14737
|
+
/* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d,
|
|
13451
14738
|
/* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
|
|
13452
14739
|
/* .clear = */ ggml_backend_vk_buffer_clear,
|
|
13453
14740
|
/* .reset = */ NULL,
|
|
@@ -13510,12 +14797,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty
|
|
|
13510
14797
|
UNUSED(buft);
|
|
13511
14798
|
}
|
|
13512
14799
|
|
|
13513
|
-
static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
|
|
13514
|
-
return GGML_VK_NAME "_Host";
|
|
13515
|
-
|
|
13516
|
-
UNUSED(buffer);
|
|
13517
|
-
}
|
|
13518
|
-
|
|
13519
14800
|
static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
13520
14801
|
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
|
|
13521
14802
|
ggml_vk_host_free(vk_instance.devices[0], buffer->context);
|
|
@@ -13603,8 +14884,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b
|
|
|
13603
14884
|
return &ctx->device->buffer_type;
|
|
13604
14885
|
}
|
|
13605
14886
|
|
|
13606
|
-
static void
|
|
13607
|
-
|
|
14887
|
+
static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset,
|
|
14888
|
+
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
|
|
14889
|
+
VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")");
|
|
13608
14890
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
13609
14891
|
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
|
13610
14892
|
|
|
@@ -13618,7 +14900,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
|
|
13618
14900
|
|
|
13619
14901
|
if (ctx->device->async_use_transfer_queue) {
|
|
13620
14902
|
if (ctx->transfer_ctx.expired()) {
|
|
13621
|
-
// Initialize new transfer context
|
|
13622
14903
|
cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
|
|
13623
14904
|
ctx->transfer_ctx = cpy_ctx;
|
|
13624
14905
|
ggml_vk_ctx_begin(ctx->device, cpy_ctx);
|
|
@@ -13633,25 +14914,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
|
|
13633
14914
|
|
|
13634
14915
|
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
13635
14916
|
|
|
13636
|
-
bool ret =
|
|
14917
|
+
bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies);
|
|
13637
14918
|
|
|
13638
14919
|
if (!ret) {
|
|
13639
|
-
|
|
14920
|
+
const size_t staging_size = size * n_copies;
|
|
14921
|
+
ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
|
|
13640
14922
|
ggml_vk_sync_buffers(nullptr, cpy_ctx);
|
|
13641
14923
|
|
|
13642
|
-
vk::BufferCopy
|
|
13643
|
-
|
|
13644
|
-
|
|
13645
|
-
|
|
14924
|
+
std::vector<vk::BufferCopy> slices(1);
|
|
14925
|
+
if (size == stride_tensor) {
|
|
14926
|
+
slices[0].srcOffset = 0;
|
|
14927
|
+
slices[0].dstOffset = dst_offset;
|
|
14928
|
+
slices[0].size = staging_size;
|
|
14929
|
+
} else {
|
|
14930
|
+
slices.resize(n_copies);
|
|
14931
|
+
for (size_t i = 0; i < n_copies; i++) {
|
|
14932
|
+
slices[i].srcOffset = i * size;
|
|
14933
|
+
slices[i].dstOffset = dst_offset + i * stride_tensor;
|
|
14934
|
+
slices[i].size = size;
|
|
14935
|
+
}
|
|
14936
|
+
}
|
|
14937
|
+
|
|
14938
|
+
cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices);
|
|
13646
14939
|
|
|
13647
|
-
|
|
13648
|
-
|
|
14940
|
+
if (size == stride_data) {
|
|
14941
|
+
deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys);
|
|
14942
|
+
} else {
|
|
14943
|
+
for (size_t i = 0; i < n_copies; i++) {
|
|
14944
|
+
deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys);
|
|
14945
|
+
}
|
|
14946
|
+
}
|
|
13649
14947
|
ggml_vk_synchronize(ctx);
|
|
13650
14948
|
}
|
|
13651
14949
|
}
|
|
13652
14950
|
|
|
13653
|
-
static void
|
|
13654
|
-
VK_LOG_DEBUG("
|
|
14951
|
+
static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
14952
|
+
VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
|
|
14953
|
+
ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
|
|
14954
|
+
}
|
|
14955
|
+
|
|
14956
|
+
static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset,
|
|
14957
|
+
size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
|
|
14958
|
+
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")");
|
|
13655
14959
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
13656
14960
|
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
|
13657
14961
|
|
|
@@ -13666,24 +14970,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
|
|
13666
14970
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
13667
14971
|
|
|
13668
14972
|
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
13669
|
-
bool ret =
|
|
14973
|
+
bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies);
|
|
13670
14974
|
|
|
13671
|
-
// If that failed, copy synchronously through a staging buffer
|
|
13672
14975
|
if (!ret) {
|
|
13673
|
-
|
|
14976
|
+
const size_t staging_size = size * n_copies;
|
|
14977
|
+
ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
|
|
13674
14978
|
ggml_vk_sync_buffers(nullptr, compute_ctx);
|
|
13675
14979
|
|
|
13676
|
-
vk::BufferCopy
|
|
13677
|
-
|
|
13678
|
-
|
|
13679
|
-
|
|
14980
|
+
std::vector<vk::BufferCopy> slices(1);
|
|
14981
|
+
if (size == stride_tensor) {
|
|
14982
|
+
slices[0].srcOffset = src_offset;
|
|
14983
|
+
slices[0].dstOffset = 0;
|
|
14984
|
+
slices[0].size = staging_size;
|
|
14985
|
+
} else {
|
|
14986
|
+
slices.resize(n_copies);
|
|
14987
|
+
for (size_t i = 0; i < n_copies; i++) {
|
|
14988
|
+
slices[i].srcOffset = src_offset + i * stride_tensor;
|
|
14989
|
+
slices[i].dstOffset = i * size;
|
|
14990
|
+
slices[i].size = size;
|
|
14991
|
+
}
|
|
14992
|
+
}
|
|
14993
|
+
|
|
14994
|
+
compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices);
|
|
13680
14995
|
|
|
13681
|
-
|
|
13682
|
-
|
|
14996
|
+
if (size == stride_data) {
|
|
14997
|
+
deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys);
|
|
14998
|
+
} else {
|
|
14999
|
+
for (size_t i = 0; i < n_copies; i++) {
|
|
15000
|
+
deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys);
|
|
15001
|
+
}
|
|
15002
|
+
}
|
|
13683
15003
|
ggml_vk_synchronize(ctx);
|
|
13684
15004
|
}
|
|
13685
15005
|
}
|
|
13686
15006
|
|
|
15007
|
+
static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
15008
|
+
VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
|
|
15009
|
+
ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
|
|
15010
|
+
}
|
|
15011
|
+
|
|
13687
15012
|
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
|
13688
15013
|
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
|
|
13689
15014
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
|
|
@@ -13797,6 +15122,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
|
|
13797
15122
|
ctx->submit_pending = false;
|
|
13798
15123
|
if (cmd_buf) {
|
|
13799
15124
|
cmd_buf->in_use = false;
|
|
15125
|
+
cmd_buf->buf.reset();
|
|
13800
15126
|
}
|
|
13801
15127
|
}
|
|
13802
15128
|
|
|
@@ -13974,6 +15300,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
|
|
|
13974
15300
|
return true;
|
|
13975
15301
|
}
|
|
13976
15302
|
|
|
15303
|
+
// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
|
|
15304
|
+
static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
|
15305
|
+
int node_idx, int num_extra) {
|
|
15306
|
+
const ggml_tensor * conv = cgraph->nodes[node_idx];
|
|
15307
|
+
if (conv->op != GGML_OP_SSM_CONV) {
|
|
15308
|
+
return false;
|
|
15309
|
+
}
|
|
15310
|
+
|
|
15311
|
+
const ggml_tensor * silu = nullptr;
|
|
15312
|
+
const ggml_tensor * bias = nullptr;
|
|
15313
|
+
|
|
15314
|
+
if (num_extra == 1) {
|
|
15315
|
+
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
|
|
15316
|
+
return false;
|
|
15317
|
+
}
|
|
15318
|
+
silu = cgraph->nodes[node_idx + 1];
|
|
15319
|
+
} else if (num_extra == 2) {
|
|
15320
|
+
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
|
|
15321
|
+
return false;
|
|
15322
|
+
}
|
|
15323
|
+
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
|
|
15324
|
+
silu = cgraph->nodes[node_idx + 2];
|
|
15325
|
+
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
|
|
15326
|
+
|
|
15327
|
+
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
|
|
15328
|
+
return false;
|
|
15329
|
+
}
|
|
15330
|
+
// bias must be channel-wise (one element per channel of the conv output)
|
|
15331
|
+
if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
|
|
15332
|
+
return false;
|
|
15333
|
+
}
|
|
15334
|
+
if (add->type != GGML_TYPE_F32) {
|
|
15335
|
+
return false;
|
|
15336
|
+
}
|
|
15337
|
+
// The shader doesn't apply per-tensor offsets, so reject misaligned bias.
|
|
15338
|
+
if (get_misalign_bytes(ctx, bias) != 0) {
|
|
15339
|
+
return false;
|
|
15340
|
+
}
|
|
15341
|
+
} else {
|
|
15342
|
+
return false;
|
|
15343
|
+
}
|
|
15344
|
+
|
|
15345
|
+
if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
|
|
15346
|
+
return false;
|
|
15347
|
+
}
|
|
15348
|
+
if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
|
|
15349
|
+
return false;
|
|
15350
|
+
}
|
|
15351
|
+
// The shader writes to the fused dst using its own strides, but the push constants don't
|
|
15352
|
+
// carry a per-tensor offset, so the binding must be naturally aligned.
|
|
15353
|
+
if (get_misalign_bytes(ctx, silu) != 0) {
|
|
15354
|
+
return false;
|
|
15355
|
+
}
|
|
15356
|
+
return true;
|
|
15357
|
+
}
|
|
15358
|
+
|
|
13977
15359
|
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
|
13978
15360
|
int node_idx, topk_moe_mode mode) {
|
|
13979
15361
|
|
|
@@ -14104,6 +15486,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
|
|
|
14104
15486
|
return true;
|
|
14105
15487
|
}
|
|
14106
15488
|
|
|
15489
|
+
// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add.
|
|
15490
|
+
// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that
|
|
15491
|
+
// the broadcast operands a and inv_b share a [1, C] layout.
|
|
15492
|
+
static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
15493
|
+
GGML_UNUSED(ctx);
|
|
15494
|
+
if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) {
|
|
15495
|
+
return false;
|
|
15496
|
+
}
|
|
15497
|
+
|
|
15498
|
+
const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
|
|
15499
|
+
const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1];
|
|
15500
|
+
const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
|
|
15501
|
+
const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
|
|
15502
|
+
const ggml_tensor * add = cgraph->nodes[node_idx + 4];
|
|
15503
|
+
|
|
15504
|
+
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
|
|
15505
|
+
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
|
|
15506
|
+
|
|
15507
|
+
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
|
|
15508
|
+
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
|
|
15509
|
+
|
|
15510
|
+
if (x_in_add != x) {
|
|
15511
|
+
return false;
|
|
15512
|
+
}
|
|
15513
|
+
if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) {
|
|
15514
|
+
return false;
|
|
15515
|
+
}
|
|
15516
|
+
// Shader bindings: data_a is A_TYPE so it follows x's precision, while
|
|
15517
|
+
// data_b and data_c are hardcoded float, so the broadcast operands must
|
|
15518
|
+
// be F32 regardless of x's type.
|
|
15519
|
+
if (a->type != GGML_TYPE_F32) return false;
|
|
15520
|
+
if (inv_b->type != GGML_TYPE_F32) return false;
|
|
15521
|
+
// Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline).
|
|
15522
|
+
if (mul0->type != x->type) return false;
|
|
15523
|
+
if (sin_node->type != x->type) return false;
|
|
15524
|
+
if (sqr->type != x->type) return false;
|
|
15525
|
+
if (mul1->type != x->type) return false;
|
|
15526
|
+
if (add->type != x->type) return false;
|
|
15527
|
+
if (!ggml_are_same_shape(a, inv_b)) {
|
|
15528
|
+
return false;
|
|
15529
|
+
}
|
|
15530
|
+
if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) {
|
|
15531
|
+
return false;
|
|
15532
|
+
}
|
|
15533
|
+
// Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b
|
|
15534
|
+
// must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader.
|
|
15535
|
+
if (x->ne[2] != 1 || x->ne[3] != 1) return false;
|
|
15536
|
+
if (add->ne[2] != 1 || add->ne[3] != 1) return false;
|
|
15537
|
+
if (a->ne[2] != 1 || a->ne[3] != 1) return false;
|
|
15538
|
+
if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false;
|
|
15539
|
+
// Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1],
|
|
15540
|
+
// so every operand must be contiguous.
|
|
15541
|
+
if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) ||
|
|
15542
|
+
!ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) {
|
|
15543
|
+
return false;
|
|
15544
|
+
}
|
|
15545
|
+
return true;
|
|
15546
|
+
}
|
|
15547
|
+
|
|
14107
15548
|
// Check whether the tensors overlap in memory.
|
|
14108
15549
|
// Fusions can potentially overwrite src tensors in ways that are not prevented
|
|
14109
15550
|
// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
|
|
@@ -14158,8 +15599,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
|
|
|
14158
15599
|
}
|
|
14159
15600
|
|
|
14160
15601
|
// conditions for pipeline creation
|
|
14161
|
-
if (
|
|
14162
|
-
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
|
|
15602
|
+
if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
|
|
14163
15603
|
return false;
|
|
14164
15604
|
}
|
|
14165
15605
|
|
|
@@ -14288,10 +15728,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
14288
15728
|
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
14289
15729
|
ctx->query_idx = 0;
|
|
14290
15730
|
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
15731
|
+
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
14291
15732
|
}
|
|
14292
15733
|
|
|
14293
15734
|
ctx->prealloc_y_last_pipeline_used = nullptr;
|
|
14294
15735
|
ctx->prealloc_y_last_tensor_used = nullptr;
|
|
15736
|
+
ctx->prealloc_y_last_decode_vector_staging = false;
|
|
14295
15737
|
|
|
14296
15738
|
if (ctx->prealloc_size_add_rms_partials) {
|
|
14297
15739
|
ggml_vk_preallocate_buffers(ctx, nullptr);
|
|
@@ -14390,6 +15832,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
14390
15832
|
// they are overwritten, and one workgroup per row. So close enough.
|
|
14391
15833
|
op_srcs_fused_elementwise[0] = true;
|
|
14392
15834
|
op_srcs_fused_elementwise[1] = true;
|
|
15835
|
+
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
|
|
15836
|
+
ctx->num_additional_fused_ops = 2;
|
|
15837
|
+
fusion_string = "SSM_CONV_BIAS_SILU";
|
|
15838
|
+
// ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
|
|
15839
|
+
// The downstream add and silu are elementwise on the conv output.
|
|
15840
|
+
op_srcs_fused_elementwise[0] = false;
|
|
15841
|
+
op_srcs_fused_elementwise[1] = true;
|
|
15842
|
+
op_srcs_fused_elementwise[2] = true;
|
|
15843
|
+
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
|
|
15844
|
+
ctx->num_additional_fused_ops = 1;
|
|
15845
|
+
fusion_string = "SSM_CONV_SILU";
|
|
15846
|
+
op_srcs_fused_elementwise[0] = false;
|
|
15847
|
+
op_srcs_fused_elementwise[1] = true;
|
|
14393
15848
|
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
|
14394
15849
|
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
|
14395
15850
|
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
|
@@ -14398,6 +15853,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
14398
15853
|
op_srcs_fused_elementwise[0] = false;
|
|
14399
15854
|
op_srcs_fused_elementwise[1] = false;
|
|
14400
15855
|
op_srcs_fused_elementwise[2] = false;
|
|
15856
|
+
} else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) {
|
|
15857
|
+
ctx->num_additional_fused_ops = 4;
|
|
15858
|
+
fusion_string = "SNAKE";
|
|
15859
|
+
// elementwise=true: snake.comp is safe under exact aliasing because each
|
|
15860
|
+
// thread reads data_x[idx] into a register before writing data_d[idx]
|
|
15861
|
+
// with a data dependency on that register. The overlap check still
|
|
15862
|
+
// rejects partial overlaps (different base or size).
|
|
15863
|
+
std::fill_n(op_srcs_fused_elementwise, 5, true);
|
|
14401
15864
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
|
14402
15865
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
|
14403
15866
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
|
@@ -14524,6 +15987,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
14524
15987
|
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
|
|
14525
15988
|
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
|
|
14526
15989
|
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
15990
|
+
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
14527
15991
|
} else {
|
|
14528
15992
|
// track a fusion string and number of fused ops for the current node_idx
|
|
14529
15993
|
ctx->query_fusion_names[i] = fusion_string;
|
|
@@ -14687,6 +16151,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|
|
14687
16151
|
if (keep_pattern(topk_moe_late_softmax)) {
|
|
14688
16152
|
continue;
|
|
14689
16153
|
}
|
|
16154
|
+
if (keep_pattern(snake_pattern)) {
|
|
16155
|
+
continue;
|
|
16156
|
+
}
|
|
14690
16157
|
|
|
14691
16158
|
// First, grab the next unused node.
|
|
14692
16159
|
current_set.push_back(first_unused);
|
|
@@ -14709,7 +16176,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|
|
14709
16176
|
if (match_pattern(topk_moe_early_softmax_norm, j) ||
|
|
14710
16177
|
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
|
|
14711
16178
|
match_pattern(topk_moe_early_softmax, j) ||
|
|
14712
|
-
match_pattern(topk_moe_late_softmax, j)
|
|
16179
|
+
match_pattern(topk_moe_late_softmax, j) ||
|
|
16180
|
+
match_pattern(snake_pattern, j)) {
|
|
14713
16181
|
continue;
|
|
14714
16182
|
}
|
|
14715
16183
|
bool ok = true;
|
|
@@ -14720,7 +16188,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|
|
14720
16188
|
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
|
|
14721
16189
|
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
|
|
14722
16190
|
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
|
|
14723
|
-
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)
|
|
16191
|
+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
|
|
16192
|
+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
|
|
16193
|
+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
|
|
14724
16194
|
ok = false;
|
|
14725
16195
|
break;
|
|
14726
16196
|
}
|
|
@@ -14803,6 +16273,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|
|
14803
16273
|
}
|
|
14804
16274
|
}
|
|
14805
16275
|
}
|
|
16276
|
+
// SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
|
|
16277
|
+
if (j > 0 &&
|
|
16278
|
+
graph->nodes[j]->op == GGML_OP_ADD &&
|
|
16279
|
+
graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
|
|
16280
|
+
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
|
|
16281
|
+
if (graph->nodes[k]->op == GGML_OP_UNARY &&
|
|
16282
|
+
graph->nodes[k]->src[0] == graph->nodes[j]) {
|
|
16283
|
+
current_set.push_back(k);
|
|
16284
|
+
used[k] = true;
|
|
16285
|
+
break;
|
|
16286
|
+
}
|
|
16287
|
+
}
|
|
16288
|
+
}
|
|
14806
16289
|
}
|
|
14807
16290
|
}
|
|
14808
16291
|
// Second pass grabs view nodes.
|
|
@@ -14858,18 +16341,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
|
|
14858
16341
|
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
14859
16342
|
auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
|
|
14860
16343
|
|
|
14861
|
-
|
|
14862
|
-
|
|
14863
|
-
|
|
14864
|
-
|
|
16344
|
+
if (vkev->has_event) {
|
|
16345
|
+
// Move existing event into submitted
|
|
16346
|
+
vkev->events_submitted.push_back(vkev->event);
|
|
16347
|
+
}
|
|
16348
|
+
|
|
16349
|
+
// Grab the next event and record it, create one if necessary
|
|
16350
|
+
if (vkev->events_free.empty()) {
|
|
16351
|
+
vkev->event = ctx->device->device.createEvent({});
|
|
16352
|
+
} else {
|
|
16353
|
+
vkev->event = vkev->events_free.back();
|
|
16354
|
+
vkev->events_free.pop_back();
|
|
16355
|
+
}
|
|
16356
|
+
|
|
16357
|
+
vkev->has_event = true;
|
|
14865
16358
|
|
|
14866
16359
|
ggml_vk_set_event(compute_ctx, vkev->event);
|
|
14867
16360
|
|
|
16361
|
+
vkev->tl_semaphore.value++;
|
|
16362
|
+
compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore);
|
|
14868
16363
|
ggml_vk_ctx_end(compute_ctx);
|
|
14869
16364
|
|
|
14870
|
-
ggml_vk_submit(compute_ctx, {
|
|
16365
|
+
ggml_vk_submit(compute_ctx, {});
|
|
14871
16366
|
ctx->submit_pending = true;
|
|
14872
16367
|
vkev->cmd_buffer = cmd_buf;
|
|
16368
|
+
vkev->cmd_buffer_use_counter = cmd_buf->use_counter;
|
|
14873
16369
|
ctx->compute_ctx.reset();
|
|
14874
16370
|
}
|
|
14875
16371
|
|
|
@@ -14880,9 +16376,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
|
|
14880
16376
|
|
|
14881
16377
|
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
14882
16378
|
|
|
14883
|
-
|
|
14884
|
-
|
|
14885
|
-
|
|
16379
|
+
if (vkev->has_event) {
|
|
16380
|
+
// Wait for latest event
|
|
16381
|
+
ggml_vk_wait_events(compute_ctx, { vkev->event });
|
|
16382
|
+
}
|
|
14886
16383
|
}
|
|
14887
16384
|
|
|
14888
16385
|
// TODO: enable async and synchronize
|
|
@@ -14891,6 +16388,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
|
|
|
14891
16388
|
/* .free = */ ggml_backend_vk_free,
|
|
14892
16389
|
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
|
|
14893
16390
|
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
|
|
16391
|
+
/* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async,
|
|
16392
|
+
/* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async,
|
|
14894
16393
|
/* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
|
|
14895
16394
|
/* .synchronize = */ ggml_backend_vk_synchronize,
|
|
14896
16395
|
/* .graph_plan_create = */ NULL,
|
|
@@ -15157,8 +16656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15157
16656
|
case GGML_GLU_OP_SWIGLU_OAI:
|
|
15158
16657
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
15159
16658
|
case GGML_GLU_OP_GEGLU_QUICK:
|
|
15160
|
-
return
|
|
15161
|
-
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
16659
|
+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
15162
16660
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
15163
16661
|
(op->src[0]->type == op->type);
|
|
15164
16662
|
default:
|
|
@@ -15178,6 +16676,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15178
16676
|
case GGML_TYPE_F32:
|
|
15179
16677
|
case GGML_TYPE_F16:
|
|
15180
16678
|
case GGML_TYPE_BF16:
|
|
16679
|
+
case GGML_TYPE_Q1_0:
|
|
15181
16680
|
case GGML_TYPE_Q4_0:
|
|
15182
16681
|
case GGML_TYPE_Q4_1:
|
|
15183
16682
|
case GGML_TYPE_Q5_0:
|
|
@@ -15198,6 +16697,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15198
16697
|
case GGML_TYPE_IQ4_XS:
|
|
15199
16698
|
case GGML_TYPE_IQ4_NL:
|
|
15200
16699
|
case GGML_TYPE_MXFP4:
|
|
16700
|
+
case GGML_TYPE_NVFP4:
|
|
15201
16701
|
break;
|
|
15202
16702
|
default:
|
|
15203
16703
|
return false;
|
|
@@ -15246,42 +16746,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15246
16746
|
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
|
15247
16747
|
return false;
|
|
15248
16748
|
}
|
|
15249
|
-
|
|
15250
|
-
|
|
15251
|
-
|
|
15252
|
-
|
|
15253
|
-
|
|
15254
|
-
|
|
15255
|
-
|
|
15256
|
-
|
|
15257
|
-
|
|
15258
|
-
|
|
15259
|
-
|
|
15260
|
-
|
|
15261
|
-
|
|
15262
|
-
|
|
15263
|
-
case GGML_TYPE_Q5_1:
|
|
15264
|
-
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
15265
|
-
//case GGML_TYPE_Q2_K:
|
|
15266
|
-
//case GGML_TYPE_Q3_K:
|
|
15267
|
-
//case GGML_TYPE_Q4_K:
|
|
15268
|
-
//case GGML_TYPE_Q5_K:
|
|
15269
|
-
//case GGML_TYPE_Q6_K:
|
|
15270
|
-
//case GGML_TYPE_IQ1_S:
|
|
15271
|
-
//case GGML_TYPE_IQ1_M:
|
|
15272
|
-
//case GGML_TYPE_IQ2_XXS:
|
|
15273
|
-
//case GGML_TYPE_IQ2_XS:
|
|
15274
|
-
//case GGML_TYPE_IQ2_S:
|
|
15275
|
-
//case GGML_TYPE_IQ3_XXS:
|
|
15276
|
-
//case GGML_TYPE_IQ3_S:
|
|
15277
|
-
//case GGML_TYPE_IQ4_XS:
|
|
15278
|
-
case GGML_TYPE_IQ4_NL:
|
|
15279
|
-
// currently supported only in coopmat2 path
|
|
15280
|
-
if (!coopmat2) {
|
|
16749
|
+
auto fa_kv_ok = [coopmat2](ggml_type t) {
|
|
16750
|
+
switch (t) {
|
|
16751
|
+
case GGML_TYPE_F32:
|
|
16752
|
+
case GGML_TYPE_F16:
|
|
16753
|
+
case GGML_TYPE_BF16:
|
|
16754
|
+
case GGML_TYPE_Q8_0:
|
|
16755
|
+
case GGML_TYPE_Q5_1:
|
|
16756
|
+
case GGML_TYPE_Q5_0:
|
|
16757
|
+
case GGML_TYPE_Q4_1:
|
|
16758
|
+
case GGML_TYPE_Q4_0:
|
|
16759
|
+
return true;
|
|
16760
|
+
case GGML_TYPE_Q1_0:
|
|
16761
|
+
return coopmat2;
|
|
16762
|
+
default:
|
|
15281
16763
|
return false;
|
|
15282
16764
|
}
|
|
15283
|
-
|
|
15284
|
-
|
|
16765
|
+
};
|
|
16766
|
+
if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
|
|
16767
|
+
return false;
|
|
16768
|
+
}
|
|
16769
|
+
if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) {
|
|
15285
16770
|
return false;
|
|
15286
16771
|
}
|
|
15287
16772
|
if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
|
|
@@ -15296,6 +16781,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15296
16781
|
case GGML_TYPE_F32:
|
|
15297
16782
|
case GGML_TYPE_F16:
|
|
15298
16783
|
case GGML_TYPE_BF16:
|
|
16784
|
+
case GGML_TYPE_Q1_0:
|
|
15299
16785
|
case GGML_TYPE_Q4_0:
|
|
15300
16786
|
case GGML_TYPE_Q4_1:
|
|
15301
16787
|
case GGML_TYPE_Q5_0:
|
|
@@ -15316,6 +16802,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15316
16802
|
case GGML_TYPE_IQ4_XS:
|
|
15317
16803
|
case GGML_TYPE_IQ4_NL:
|
|
15318
16804
|
case GGML_TYPE_MXFP4:
|
|
16805
|
+
case GGML_TYPE_NVFP4:
|
|
15319
16806
|
case GGML_TYPE_I32:
|
|
15320
16807
|
return true;
|
|
15321
16808
|
default:
|
|
@@ -15328,6 +16815,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15328
16815
|
case GGML_TYPE_F32:
|
|
15329
16816
|
case GGML_TYPE_F16:
|
|
15330
16817
|
case GGML_TYPE_BF16:
|
|
16818
|
+
case GGML_TYPE_Q1_0:
|
|
15331
16819
|
case GGML_TYPE_Q4_0:
|
|
15332
16820
|
case GGML_TYPE_Q4_1:
|
|
15333
16821
|
case GGML_TYPE_Q5_0:
|
|
@@ -15351,6 +16839,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15351
16839
|
case GGML_TYPE_F32:
|
|
15352
16840
|
case GGML_TYPE_F16:
|
|
15353
16841
|
case GGML_TYPE_BF16:
|
|
16842
|
+
case GGML_TYPE_Q1_0:
|
|
15354
16843
|
case GGML_TYPE_Q4_0:
|
|
15355
16844
|
case GGML_TYPE_Q4_1:
|
|
15356
16845
|
case GGML_TYPE_Q5_0:
|
|
@@ -15365,6 +16854,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15365
16854
|
if (src1_type == GGML_TYPE_F32) {
|
|
15366
16855
|
switch (src0_type) {
|
|
15367
16856
|
case GGML_TYPE_F16:
|
|
16857
|
+
case GGML_TYPE_BF16:
|
|
16858
|
+
case GGML_TYPE_Q1_0:
|
|
15368
16859
|
case GGML_TYPE_Q4_0:
|
|
15369
16860
|
case GGML_TYPE_Q4_1:
|
|
15370
16861
|
case GGML_TYPE_Q5_0:
|
|
@@ -15400,7 +16891,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15400
16891
|
return false;
|
|
15401
16892
|
}
|
|
15402
16893
|
case GGML_OP_REPEAT:
|
|
15403
|
-
return ggml_type_size(op->type) ==
|
|
16894
|
+
return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) &&
|
|
16895
|
+
(ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2);
|
|
15404
16896
|
case GGML_OP_REPEAT_BACK:
|
|
15405
16897
|
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
|
15406
16898
|
case GGML_OP_ROPE:
|
|
@@ -15492,8 +16984,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
15492
16984
|
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
|
|
15493
16985
|
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
|
|
15494
16986
|
case GGML_OP_ARANGE:
|
|
15495
|
-
case GGML_OP_FILL:
|
|
15496
16987
|
return op->type == GGML_TYPE_F32;
|
|
16988
|
+
case GGML_OP_FILL:
|
|
16989
|
+
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
|
15497
16990
|
case GGML_OP_SCALE:
|
|
15498
16991
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
15499
16992
|
case GGML_OP_PAD:
|
|
@@ -15672,10 +17165,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t
|
|
|
15672
17165
|
return nullptr;
|
|
15673
17166
|
}
|
|
15674
17167
|
|
|
15675
|
-
//
|
|
15676
|
-
vkev->
|
|
15677
|
-
|
|
15678
|
-
|
|
17168
|
+
// No events initially, they get created on demand
|
|
17169
|
+
vkev->has_event = false;
|
|
17170
|
+
|
|
17171
|
+
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
|
|
17172
|
+
vk::SemaphoreCreateInfo ci{};
|
|
17173
|
+
ci.setPNext(&tci);
|
|
17174
|
+
vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 };
|
|
15679
17175
|
|
|
15680
17176
|
return new ggml_backend_event {
|
|
15681
17177
|
/* .device = */ dev,
|
|
@@ -15689,8 +17185,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
|
|
|
15689
17185
|
|
|
15690
17186
|
vk_event *vkev = (vk_event *)event->context;
|
|
15691
17187
|
|
|
15692
|
-
device->device.
|
|
15693
|
-
|
|
17188
|
+
device->device.destroySemaphore(vkev->tl_semaphore.s);
|
|
17189
|
+
for (auto& event : vkev->events_free) {
|
|
17190
|
+
device->device.destroyEvent(event);
|
|
17191
|
+
}
|
|
17192
|
+
for (auto& event : vkev->events_submitted) {
|
|
17193
|
+
device->device.destroyEvent(event);
|
|
17194
|
+
}
|
|
17195
|
+
if (vkev->has_event) {
|
|
17196
|
+
device->device.destroyEvent(vkev->event);
|
|
17197
|
+
}
|
|
15694
17198
|
delete vkev;
|
|
15695
17199
|
delete event;
|
|
15696
17200
|
}
|
|
@@ -15701,10 +17205,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
|
|
|
15701
17205
|
auto device = ggml_vk_get_device(ctx->device);
|
|
15702
17206
|
vk_event *vkev = (vk_event *)event->context;
|
|
15703
17207
|
|
|
15704
|
-
|
|
15705
|
-
|
|
15706
|
-
|
|
15707
|
-
|
|
17208
|
+
// Only do something if the event has actually been used
|
|
17209
|
+
if (vkev->has_event) {
|
|
17210
|
+
vk::Semaphore sem = vkev->tl_semaphore.s;
|
|
17211
|
+
uint64_t val = vkev->tl_semaphore.value;
|
|
17212
|
+
vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val};
|
|
17213
|
+
VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize");
|
|
17214
|
+
|
|
17215
|
+
// Reset and move submitted events
|
|
17216
|
+
for (auto& event : vkev->events_submitted) {
|
|
17217
|
+
device->device.resetEvent(event);
|
|
17218
|
+
}
|
|
17219
|
+
vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end());
|
|
17220
|
+
vkev->events_submitted.clear();
|
|
17221
|
+
|
|
17222
|
+
// Finished using current command buffer so we flag for reuse
|
|
17223
|
+
if (vkev->cmd_buffer) {
|
|
17224
|
+
// Only flag for reuse if it hasn't been reused already
|
|
17225
|
+
if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) {
|
|
17226
|
+
vkev->cmd_buffer->in_use = false;
|
|
17227
|
+
vkev->cmd_buffer->buf.reset();
|
|
17228
|
+
}
|
|
17229
|
+
vkev->cmd_buffer = nullptr;
|
|
17230
|
+
}
|
|
15708
17231
|
}
|
|
15709
17232
|
}
|
|
15710
17233
|
|
|
@@ -15958,6 +17481,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
|
|
15958
17481
|
case 0xE20C: // B570
|
|
15959
17482
|
return 18;
|
|
15960
17483
|
case 0xE20B: // B580
|
|
17484
|
+
case 0xE211: // Pro B60
|
|
15961
17485
|
return 20;
|
|
15962
17486
|
default:
|
|
15963
17487
|
return 0;
|
|
@@ -16450,7 +17974,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
16450
17974
|
src_clone[4], src_clone[5], src_clone[6]);
|
|
16451
17975
|
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
|
|
16452
17976
|
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
|
|
16453
|
-
src_clone[2], src_clone[3], src_clone[4], src_clone[5]
|
|
17977
|
+
src_clone[2], src_clone[3], src_clone[4], src_clone[5],
|
|
17978
|
+
ggml_get_op_params_i32(tensor, 0));
|
|
16454
17979
|
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
|
16455
17980
|
src_clone[0]->flags = tensor->src[0]->flags;
|
|
16456
17981
|
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|