whispercpp 1.3.6 → 1.3.7
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/README.md +38 -5
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -8
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +36 -42
- data/ext/ruby_whisper.h +135 -0
- data/ext/ruby_whisper_context.c +107 -28
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -65
- data/ext/ruby_whisper_segment.c +6 -6
- data/ext/ruby_whisper_transcribe.cpp +42 -15
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +1 -1
- data/ext/sources/examples/cli/cli.cpp +43 -9
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +199 -163
- data/ext/sources/ggml/CMakeLists.txt +21 -13
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +72 -10
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-rpc.h +3 -3
- data/ext/sources/ggml/include/ggml.h +101 -9
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +22 -5
- data/ext/sources/ggml/src/ggml-alloc.c +5 -1
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
- data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
- data/ext/sources/ggml/src/ggml-impl.h +6 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
- data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +289 -114
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
- data/ext/sources/ggml/src/ggml.c +110 -28
- data/ext/sources/ggml/src/gguf.cpp +173 -28
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +56 -12
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +411 -62
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +24 -6
- data/whispercpp.gemspec +2 -2
- metadata +215 -281
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
- data/ext/sources/examples/talk-llama/llama-context.h +0 -359
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
- data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
- data/ext/sources/examples/talk-llama/llama-model.h +0 -597
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
- data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
- data/ext/sources/examples/talk-llama/llama.h +0 -1573
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -704
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
- /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
|
2
2
|
#define GGML_WEBGPU_SHADER_LIB_HPP
|
|
3
3
|
|
|
4
|
+
#include "ggml-impl.h"
|
|
4
5
|
#include "ggml-wgsl-shaders.hpp"
|
|
5
6
|
#include "ggml.h"
|
|
6
7
|
#include "pre_wgsl.hpp"
|
|
@@ -17,6 +18,9 @@
|
|
|
17
18
|
#define GGML_WEBGPU_F32_SIZE_BYTES 4
|
|
18
19
|
#define GGML_WEBGPU_I32_SIZE_BYTES 4
|
|
19
20
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
|
|
21
|
+
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u
|
|
22
|
+
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u
|
|
23
|
+
#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u
|
|
20
24
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
|
|
21
25
|
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
|
|
22
26
|
#define GGML_WEBGPU_KV_SEQ_PAD 256u
|
|
@@ -26,38 +30,32 @@
|
|
|
26
30
|
// Matrix multiplication parameters
|
|
27
31
|
|
|
28
32
|
// Register tiling parameters
|
|
29
|
-
#define WEBGPU_MUL_MAT_TILE_M
|
|
30
|
-
#define WEBGPU_MUL_MAT_TILE_N
|
|
31
|
-
#define WEBGPU_MUL_MAT_WG_SIZE_M
|
|
32
|
-
#define WEBGPU_MUL_MAT_WG_SIZE_N
|
|
33
|
-
#define
|
|
33
|
+
#define WEBGPU_MUL_MAT_TILE_M 4
|
|
34
|
+
#define WEBGPU_MUL_MAT_TILE_N 4
|
|
35
|
+
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
|
36
|
+
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
|
37
|
+
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
|
|
38
|
+
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32
|
|
34
39
|
|
|
35
40
|
// Subgroup matrix parameters
|
|
36
41
|
// The number of subgroups in the M dimension
|
|
37
|
-
#define WEBGPU_MUL_MAT_SUBGROUP_M
|
|
42
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
|
38
43
|
// The number of subgroups in the N dimension
|
|
39
|
-
#define WEBGPU_MUL_MAT_SUBGROUP_N
|
|
44
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_N 4
|
|
40
45
|
// The number of subgroup matrices each subgroup accumulates over
|
|
41
|
-
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M
|
|
42
|
-
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N
|
|
46
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
|
47
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
|
48
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
|
|
49
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32
|
|
43
50
|
|
|
44
51
|
// Matrix-vector multiplication parameters
|
|
45
52
|
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
|
46
53
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
#define
|
|
50
|
-
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
|
54
|
+
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
|
|
55
|
+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
|
|
56
|
+
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
|
|
51
57
|
|
|
52
|
-
|
|
53
|
-
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
|
54
|
-
|
|
55
|
-
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
|
56
|
-
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
|
57
|
-
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
|
58
|
-
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
|
59
|
-
|
|
60
|
-
// default size for legacy matrix multiplication
|
|
58
|
+
// default size for reg-tile matrix multiplication
|
|
61
59
|
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
|
62
60
|
|
|
63
61
|
// Same hash combine function as in boost
|
|
@@ -65,24 +63,41 @@ template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const
|
|
|
65
63
|
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
|
66
64
|
}
|
|
67
65
|
|
|
66
|
+
// Calculates base address of a tensor ignoring the fake base pointer
|
|
67
|
+
inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
|
|
68
|
+
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
|
|
69
|
+
return (uintptr_t) base_tensor->data + tensor->view_offs;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
|
|
73
|
+
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
|
|
77
|
+
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
|
|
78
|
+
ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
|
|
79
|
+
}
|
|
80
|
+
|
|
68
81
|
struct ggml_webgpu_shader_lib_context {
|
|
69
82
|
ggml_tensor * src0;
|
|
70
83
|
ggml_tensor * src1;
|
|
71
84
|
ggml_tensor * src2;
|
|
72
85
|
ggml_tensor * src3;
|
|
73
86
|
ggml_tensor * src4;
|
|
87
|
+
ggml_tensor * src5;
|
|
74
88
|
ggml_tensor * dst;
|
|
75
89
|
|
|
76
|
-
uint32_t
|
|
77
|
-
size_t
|
|
78
|
-
bool
|
|
79
|
-
bool
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
uint32_t
|
|
83
|
-
uint32_t
|
|
84
|
-
uint32_t
|
|
85
|
-
|
|
90
|
+
uint32_t max_wg_size;
|
|
91
|
+
size_t wg_mem_limit_bytes = 0;
|
|
92
|
+
bool supports_subgroups = false;
|
|
93
|
+
bool supports_subgroup_matrix = false;
|
|
94
|
+
uint32_t sg_mat_m = 0;
|
|
95
|
+
uint32_t sg_mat_n = 0;
|
|
96
|
+
uint32_t sg_mat_k = 0;
|
|
97
|
+
uint32_t min_subgroup_size = 0;
|
|
98
|
+
uint32_t max_subgroup_size = 0;
|
|
99
|
+
bool supports_dot_product = false;
|
|
100
|
+
std::string vendor;
|
|
86
101
|
};
|
|
87
102
|
|
|
88
103
|
struct webgpu_pipeline {
|
|
@@ -93,6 +108,51 @@ struct webgpu_pipeline {
|
|
|
93
108
|
|
|
94
109
|
struct ggml_webgpu_generic_shader_decisions {
|
|
95
110
|
uint32_t wg_size = 0;
|
|
111
|
+
bool inplace = false;
|
|
112
|
+
};
|
|
113
|
+
|
|
114
|
+
struct ggml_webgpu_binary_shader_decisions {
|
|
115
|
+
uint32_t wg_size = 0;
|
|
116
|
+
bool inplace = false;
|
|
117
|
+
bool overlap = false;
|
|
118
|
+
bool src_overlap = false;
|
|
119
|
+
};
|
|
120
|
+
|
|
121
|
+
struct ggml_webgpu_processed_shader {
|
|
122
|
+
std::string wgsl;
|
|
123
|
+
std::string variant;
|
|
124
|
+
std::shared_ptr<void> decisions;
|
|
125
|
+
};
|
|
126
|
+
|
|
127
|
+
struct ggml_webgpu_ssm_conv_shader_decisions {
|
|
128
|
+
uint32_t block_size;
|
|
129
|
+
uint32_t tokens_per_wg;
|
|
130
|
+
};
|
|
131
|
+
|
|
132
|
+
struct ggml_webgpu_ssm_scan_pipeline_key {
|
|
133
|
+
int type;
|
|
134
|
+
int d_state;
|
|
135
|
+
bool xbc_overlap;
|
|
136
|
+
|
|
137
|
+
bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const {
|
|
138
|
+
return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap;
|
|
139
|
+
}
|
|
140
|
+
};
|
|
141
|
+
|
|
142
|
+
struct ggml_webgpu_ssm_scan_pipeline_key_hash {
|
|
143
|
+
size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
|
|
144
|
+
size_t seed = 0;
|
|
145
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
146
|
+
ggml_webgpu_hash_combine(seed, key.d_state);
|
|
147
|
+
ggml_webgpu_hash_combine(seed, key.xbc_overlap);
|
|
148
|
+
return seed;
|
|
149
|
+
}
|
|
150
|
+
};
|
|
151
|
+
|
|
152
|
+
struct ggml_webgpu_ssm_scan_shader_decisions {
|
|
153
|
+
uint32_t wg_size;
|
|
154
|
+
uint32_t tokens_per_tile;
|
|
155
|
+
bool xbc_overlap = false;
|
|
96
156
|
};
|
|
97
157
|
|
|
98
158
|
/** Argsort **/
|
|
@@ -109,9 +169,11 @@ struct ggml_webgpu_set_rows_pipeline_key {
|
|
|
109
169
|
int dst_type;
|
|
110
170
|
int vec4;
|
|
111
171
|
int i64_idx;
|
|
172
|
+
int pair_blocks;
|
|
112
173
|
|
|
113
174
|
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
|
|
114
|
-
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx
|
|
175
|
+
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx &&
|
|
176
|
+
pair_blocks == other.pair_blocks;
|
|
115
177
|
}
|
|
116
178
|
};
|
|
117
179
|
|
|
@@ -121,6 +183,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|
|
121
183
|
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
122
184
|
ggml_webgpu_hash_combine(seed, key.vec4);
|
|
123
185
|
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
|
186
|
+
ggml_webgpu_hash_combine(seed, key.pair_blocks);
|
|
124
187
|
return seed;
|
|
125
188
|
}
|
|
126
189
|
};
|
|
@@ -128,9 +191,30 @@ struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|
|
128
191
|
struct ggml_webgpu_set_rows_shader_decisions {
|
|
129
192
|
bool vec4;
|
|
130
193
|
bool i64_idx;
|
|
194
|
+
bool pair_blocks;
|
|
131
195
|
uint32_t wg_size;
|
|
132
196
|
};
|
|
133
197
|
|
|
198
|
+
/** Set **/
|
|
199
|
+
|
|
200
|
+
struct ggml_webgpu_set_pipeline_key {
|
|
201
|
+
ggml_type type;
|
|
202
|
+
bool inplace;
|
|
203
|
+
|
|
204
|
+
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
|
|
205
|
+
return type == other.type && inplace == other.inplace;
|
|
206
|
+
}
|
|
207
|
+
};
|
|
208
|
+
|
|
209
|
+
struct ggml_webgpu_set_pipeline_key_hash {
|
|
210
|
+
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
|
|
211
|
+
size_t seed = 0;
|
|
212
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
213
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
214
|
+
return seed;
|
|
215
|
+
}
|
|
216
|
+
};
|
|
217
|
+
|
|
134
218
|
/** Get Rows **/
|
|
135
219
|
|
|
136
220
|
struct ggml_webgpu_get_rows_pipeline_key {
|
|
@@ -151,6 +235,59 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|
|
151
235
|
}
|
|
152
236
|
};
|
|
153
237
|
|
|
238
|
+
/** Row Norm **/
|
|
239
|
+
|
|
240
|
+
struct ggml_webgpu_row_norm_pipeline_key {
|
|
241
|
+
ggml_op op;
|
|
242
|
+
ggml_type src_type;
|
|
243
|
+
ggml_type dst_type;
|
|
244
|
+
bool inplace;
|
|
245
|
+
|
|
246
|
+
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
|
247
|
+
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
|
|
248
|
+
}
|
|
249
|
+
};
|
|
250
|
+
|
|
251
|
+
struct ggml_webgpu_row_norm_pipeline_key_hash {
|
|
252
|
+
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
|
253
|
+
size_t seed = 0;
|
|
254
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
255
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
256
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
257
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
258
|
+
return seed;
|
|
259
|
+
}
|
|
260
|
+
};
|
|
261
|
+
|
|
262
|
+
/** RMS_NORM + MUL **/
|
|
263
|
+
|
|
264
|
+
struct ggml_webgpu_rms_norm_mul_pipeline_key {
|
|
265
|
+
bool inplace; // rn_src == dst
|
|
266
|
+
bool overlap; // mul_src == dst
|
|
267
|
+
bool src_overlap; // rn_src == mul_src
|
|
268
|
+
|
|
269
|
+
bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
|
|
270
|
+
return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
|
271
|
+
}
|
|
272
|
+
};
|
|
273
|
+
|
|
274
|
+
struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
|
275
|
+
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
|
|
276
|
+
size_t seed = 0;
|
|
277
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
278
|
+
ggml_webgpu_hash_combine(seed, key.overlap);
|
|
279
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
280
|
+
return seed;
|
|
281
|
+
}
|
|
282
|
+
};
|
|
283
|
+
|
|
284
|
+
struct ggml_webgpu_rms_norm_mul_shader_decisions {
|
|
285
|
+
uint32_t wg_size = 0;
|
|
286
|
+
bool inplace = false;
|
|
287
|
+
bool overlap = false;
|
|
288
|
+
bool src_overlap = false;
|
|
289
|
+
};
|
|
290
|
+
|
|
154
291
|
/** Pad **/
|
|
155
292
|
struct ggml_webgpu_pad_pipeline_key {
|
|
156
293
|
bool circular;
|
|
@@ -166,6 +303,107 @@ struct ggml_webgpu_pad_pipeline_key_hash {
|
|
|
166
303
|
}
|
|
167
304
|
};
|
|
168
305
|
|
|
306
|
+
/** Solve Tri **/
|
|
307
|
+
struct ggml_webgpu_solve_tri_pipeline_key {
|
|
308
|
+
int type;
|
|
309
|
+
int n;
|
|
310
|
+
int k;
|
|
311
|
+
|
|
312
|
+
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
|
|
313
|
+
return type == other.type && n == other.n && k == other.k;
|
|
314
|
+
}
|
|
315
|
+
};
|
|
316
|
+
|
|
317
|
+
struct ggml_webgpu_solve_tri_pipeline_key_hash {
|
|
318
|
+
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
|
|
319
|
+
size_t seed = 0;
|
|
320
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
321
|
+
ggml_webgpu_hash_combine(seed, key.n);
|
|
322
|
+
ggml_webgpu_hash_combine(seed, key.k);
|
|
323
|
+
return seed;
|
|
324
|
+
}
|
|
325
|
+
};
|
|
326
|
+
|
|
327
|
+
/** SSM Conv **/
|
|
328
|
+
struct ggml_webgpu_ssm_conv_pipeline_key {
|
|
329
|
+
int type;
|
|
330
|
+
int vectorized;
|
|
331
|
+
|
|
332
|
+
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
|
|
333
|
+
return type == other.type && vectorized == other.vectorized;
|
|
334
|
+
}
|
|
335
|
+
};
|
|
336
|
+
|
|
337
|
+
/** CONV 2D */
|
|
338
|
+
struct ggml_webgpu_conv2d_pipeline_key {
|
|
339
|
+
ggml_type weight_type;
|
|
340
|
+
ggml_type input_type;
|
|
341
|
+
ggml_type output_type;
|
|
342
|
+
|
|
343
|
+
bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
|
|
344
|
+
return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
|
|
345
|
+
}
|
|
346
|
+
};
|
|
347
|
+
|
|
348
|
+
struct ggml_webgpu_conv2d_pipeline_key_hash {
|
|
349
|
+
size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
|
|
350
|
+
size_t seed = 0;
|
|
351
|
+
ggml_webgpu_hash_combine(seed, key.weight_type);
|
|
352
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
353
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
354
|
+
return seed;
|
|
355
|
+
}
|
|
356
|
+
};
|
|
357
|
+
|
|
358
|
+
/** Im2Col **/
|
|
359
|
+
struct ggml_webgpu_im2col_pipeline_key {
|
|
360
|
+
ggml_type input_type;
|
|
361
|
+
ggml_type output_type;
|
|
362
|
+
|
|
363
|
+
bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
|
|
364
|
+
return input_type == other.input_type && output_type == other.output_type;
|
|
365
|
+
}
|
|
366
|
+
};
|
|
367
|
+
|
|
368
|
+
struct ggml_webgpu_im2col_pipeline_key_hash {
|
|
369
|
+
size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
|
|
370
|
+
size_t seed = 0;
|
|
371
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
372
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
373
|
+
return seed;
|
|
374
|
+
}
|
|
375
|
+
};
|
|
376
|
+
|
|
377
|
+
/** Gated Delta Net **/
|
|
378
|
+
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
|
379
|
+
int type;
|
|
380
|
+
int s_v;
|
|
381
|
+
int kda;
|
|
382
|
+
|
|
383
|
+
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
|
|
384
|
+
return type == other.type && s_v == other.s_v && kda == other.kda;
|
|
385
|
+
}
|
|
386
|
+
};
|
|
387
|
+
|
|
388
|
+
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
|
|
389
|
+
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
|
|
390
|
+
size_t seed = 0;
|
|
391
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
392
|
+
ggml_webgpu_hash_combine(seed, key.s_v);
|
|
393
|
+
ggml_webgpu_hash_combine(seed, key.kda);
|
|
394
|
+
return seed;
|
|
395
|
+
}
|
|
396
|
+
};
|
|
397
|
+
|
|
398
|
+
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
|
|
399
|
+
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
|
|
400
|
+
size_t seed = 0;
|
|
401
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
402
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
403
|
+
return seed;
|
|
404
|
+
}
|
|
405
|
+
};
|
|
406
|
+
|
|
169
407
|
/** Scale **/
|
|
170
408
|
|
|
171
409
|
struct ggml_webgpu_scale_pipeline_key {
|
|
@@ -182,18 +420,47 @@ struct ggml_webgpu_scale_pipeline_key_hash {
|
|
|
182
420
|
}
|
|
183
421
|
};
|
|
184
422
|
|
|
423
|
+
/** Upscale **/
|
|
424
|
+
|
|
425
|
+
struct ggml_webgpu_upscale_pipeline_key {
|
|
426
|
+
ggml_type input_type;
|
|
427
|
+
ggml_type output_type;
|
|
428
|
+
uint32_t base_mode;
|
|
429
|
+
bool antialias;
|
|
430
|
+
|
|
431
|
+
bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
|
|
432
|
+
return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode &&
|
|
433
|
+
antialias == other.antialias;
|
|
434
|
+
}
|
|
435
|
+
};
|
|
436
|
+
|
|
437
|
+
struct ggml_webgpu_upscale_pipeline_key_hash {
|
|
438
|
+
size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
|
|
439
|
+
size_t seed = 0;
|
|
440
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
441
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
442
|
+
ggml_webgpu_hash_combine(seed, key.base_mode);
|
|
443
|
+
ggml_webgpu_hash_combine(seed, key.antialias);
|
|
444
|
+
return seed;
|
|
445
|
+
}
|
|
446
|
+
};
|
|
447
|
+
|
|
185
448
|
/** Concat **/
|
|
186
449
|
|
|
187
450
|
struct ggml_webgpu_concat_pipeline_key {
|
|
188
|
-
int
|
|
451
|
+
int type;
|
|
452
|
+
bool src_overlap;
|
|
189
453
|
|
|
190
|
-
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
|
|
454
|
+
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
|
|
455
|
+
return type == other.type && src_overlap == other.src_overlap;
|
|
456
|
+
}
|
|
191
457
|
};
|
|
192
458
|
|
|
193
459
|
struct ggml_webgpu_concat_pipeline_key_hash {
|
|
194
460
|
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
|
|
195
461
|
size_t seed = 0;
|
|
196
462
|
ggml_webgpu_hash_combine(seed, key.type);
|
|
463
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
197
464
|
return seed;
|
|
198
465
|
}
|
|
199
466
|
};
|
|
@@ -241,16 +508,34 @@ struct ggml_webgpu_binary_pipeline_key_hash {
|
|
|
241
508
|
}
|
|
242
509
|
};
|
|
243
510
|
|
|
511
|
+
/* Add_Id */
|
|
512
|
+
|
|
513
|
+
struct ggml_webgpu_add_id_pipeline_key {
|
|
514
|
+
bool inplace;
|
|
515
|
+
|
|
516
|
+
bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; }
|
|
517
|
+
};
|
|
518
|
+
|
|
519
|
+
struct ggml_webgpu_add_id_pipeline_key_hash {
|
|
520
|
+
size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
|
|
521
|
+
size_t seed = 0;
|
|
522
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
523
|
+
return seed;
|
|
524
|
+
}
|
|
525
|
+
};
|
|
526
|
+
|
|
244
527
|
/** Unary **/
|
|
245
528
|
|
|
246
529
|
struct ggml_webgpu_unary_pipeline_key {
|
|
247
|
-
int
|
|
248
|
-
int
|
|
249
|
-
bool
|
|
250
|
-
bool
|
|
530
|
+
int type;
|
|
531
|
+
int op;
|
|
532
|
+
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
|
533
|
+
bool inplace;
|
|
534
|
+
ggml_tri_type ttype; // only used for GGML_OP_TRI
|
|
251
535
|
|
|
252
536
|
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
|
253
|
-
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace
|
|
537
|
+
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
|
|
538
|
+
ttype == other.ttype;
|
|
254
539
|
}
|
|
255
540
|
};
|
|
256
541
|
|
|
@@ -261,58 +546,285 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
|
|
261
546
|
ggml_webgpu_hash_combine(seed, key.op);
|
|
262
547
|
ggml_webgpu_hash_combine(seed, key.is_unary);
|
|
263
548
|
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
549
|
+
ggml_webgpu_hash_combine(seed, key.ttype);
|
|
264
550
|
return seed;
|
|
265
551
|
}
|
|
266
552
|
};
|
|
267
553
|
|
|
268
554
|
/** FlashAttention */
|
|
269
555
|
|
|
270
|
-
struct
|
|
271
|
-
ggml_type
|
|
556
|
+
struct ggml_webgpu_flash_attn_common_pipeline_key {
|
|
557
|
+
ggml_type q_type;
|
|
558
|
+
ggml_type k_type;
|
|
559
|
+
ggml_type v_type;
|
|
560
|
+
ggml_type dst_type;
|
|
272
561
|
uint32_t head_dim_qk;
|
|
273
562
|
uint32_t head_dim_v;
|
|
274
563
|
bool kv_direct;
|
|
564
|
+
bool kv_overlap;
|
|
275
565
|
bool has_mask;
|
|
276
566
|
bool has_sinks;
|
|
277
567
|
bool uses_logit_softcap;
|
|
278
568
|
|
|
569
|
+
bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
|
|
570
|
+
return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
|
|
571
|
+
dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
|
572
|
+
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
|
|
573
|
+
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
|
|
574
|
+
}
|
|
575
|
+
};
|
|
576
|
+
|
|
577
|
+
inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed,
|
|
578
|
+
const ggml_webgpu_flash_attn_common_pipeline_key & key) {
|
|
579
|
+
ggml_webgpu_hash_combine(seed, key.q_type);
|
|
580
|
+
ggml_webgpu_hash_combine(seed, key.k_type);
|
|
581
|
+
ggml_webgpu_hash_combine(seed, key.v_type);
|
|
582
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
583
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
|
584
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
585
|
+
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
|
586
|
+
ggml_webgpu_hash_combine(seed, key.kv_overlap);
|
|
587
|
+
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
588
|
+
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
|
589
|
+
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
struct ggml_webgpu_flash_attn_vec_pipeline_key {
|
|
593
|
+
ggml_webgpu_flash_attn_common_pipeline_key common;
|
|
594
|
+
|
|
595
|
+
bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; }
|
|
596
|
+
};
|
|
597
|
+
|
|
598
|
+
struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
|
|
599
|
+
size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
|
|
600
|
+
size_t seed = 0;
|
|
601
|
+
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
|
602
|
+
return seed;
|
|
603
|
+
}
|
|
604
|
+
};
|
|
605
|
+
|
|
606
|
+
struct ggml_webgpu_flash_attn_pipeline_key {
|
|
607
|
+
ggml_webgpu_flash_attn_common_pipeline_key common;
|
|
608
|
+
bool use_sg_matrix;
|
|
609
|
+
|
|
279
610
|
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
|
280
|
-
return
|
|
281
|
-
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
|
282
|
-
uses_logit_softcap == other.uses_logit_softcap;
|
|
611
|
+
return common == other.common && use_sg_matrix == other.use_sg_matrix;
|
|
283
612
|
}
|
|
284
613
|
};
|
|
285
614
|
|
|
286
615
|
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|
287
616
|
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
|
288
617
|
size_t seed = 0;
|
|
289
|
-
|
|
290
|
-
ggml_webgpu_hash_combine(seed, key.
|
|
618
|
+
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
|
619
|
+
ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
|
|
620
|
+
return seed;
|
|
621
|
+
}
|
|
622
|
+
};
|
|
623
|
+
|
|
624
|
+
struct ggml_webgpu_flash_attn_vec_decisions {
|
|
625
|
+
uint32_t kv_tile = 0;
|
|
626
|
+
uint32_t wg_size = 0;
|
|
627
|
+
};
|
|
628
|
+
|
|
629
|
+
struct ggml_webgpu_flash_attn_decisions {
|
|
630
|
+
bool use_sg_matrix = false;
|
|
631
|
+
uint32_t q_tile = 0;
|
|
632
|
+
uint32_t kv_tile = 0;
|
|
633
|
+
uint32_t wg_size = 0;
|
|
634
|
+
};
|
|
635
|
+
|
|
636
|
+
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
|
|
637
|
+
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;
|
|
638
|
+
|
|
639
|
+
inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
|
|
640
|
+
constexpr uintptr_t ptr_base_addr = 0x1000u;
|
|
641
|
+
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
|
|
642
|
+
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
|
|
646
|
+
const uint32_t offset_elems =
|
|
647
|
+
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
|
|
648
|
+
ggml_type_size(K->type));
|
|
649
|
+
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
|
|
653
|
+
const ggml_tensor * V,
|
|
654
|
+
size_t storage_offset_alignment) {
|
|
655
|
+
return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) &&
|
|
656
|
+
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
|
|
660
|
+
const ggml_tensor * K,
|
|
661
|
+
const ggml_tensor * V,
|
|
662
|
+
uint32_t kv_direct_align) {
|
|
663
|
+
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
|
|
664
|
+
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
|
|
668
|
+
const ggml_webgpu_shader_lib_context & context,
|
|
669
|
+
uint32_t kv_direct_align) {
|
|
670
|
+
ggml_webgpu_flash_attn_common_pipeline_key key = {};
|
|
671
|
+
key.q_type = context.src0->type;
|
|
672
|
+
key.k_type = context.src1->type;
|
|
673
|
+
key.v_type = context.src2->type;
|
|
674
|
+
key.dst_type = context.dst->type;
|
|
675
|
+
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
|
676
|
+
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
|
677
|
+
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
|
|
678
|
+
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
|
679
|
+
key.has_mask = context.src3 != nullptr;
|
|
680
|
+
key.has_sinks = context.src4 != nullptr;
|
|
681
|
+
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
|
682
|
+
return key;
|
|
683
|
+
}
|
|
684
|
+
|
|
685
|
+
inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
|
|
686
|
+
const ggml_webgpu_flash_attn_common_pipeline_key & key,
|
|
687
|
+
std::string & variant,
|
|
688
|
+
uint32_t q_tile,
|
|
689
|
+
uint32_t kv_tile,
|
|
690
|
+
uint32_t wg_size) {
|
|
691
|
+
std::vector<std::string> defines;
|
|
692
|
+
|
|
693
|
+
switch (key.k_type) {
|
|
694
|
+
case GGML_TYPE_F32:
|
|
695
|
+
defines.push_back("K_F32");
|
|
696
|
+
break;
|
|
697
|
+
case GGML_TYPE_F16:
|
|
698
|
+
defines.push_back("K_F16");
|
|
699
|
+
break;
|
|
700
|
+
case GGML_TYPE_Q4_0:
|
|
701
|
+
defines.push_back("K_Q4_0");
|
|
702
|
+
break;
|
|
703
|
+
case GGML_TYPE_Q8_0:
|
|
704
|
+
defines.push_back("K_Q8_0");
|
|
705
|
+
break;
|
|
706
|
+
default:
|
|
707
|
+
GGML_ABORT("Unsupported K type for flash attention shader");
|
|
708
|
+
}
|
|
709
|
+
variant += std::string("_k") + ggml_type_name(key.k_type);
|
|
710
|
+
|
|
711
|
+
switch (key.v_type) {
|
|
712
|
+
case GGML_TYPE_F32:
|
|
713
|
+
defines.push_back("V_F32");
|
|
714
|
+
break;
|
|
715
|
+
case GGML_TYPE_F16:
|
|
716
|
+
defines.push_back("V_F16");
|
|
717
|
+
break;
|
|
718
|
+
case GGML_TYPE_Q4_0:
|
|
719
|
+
defines.push_back("V_Q4_0");
|
|
720
|
+
break;
|
|
721
|
+
case GGML_TYPE_Q8_0:
|
|
722
|
+
defines.push_back("V_Q8_0");
|
|
723
|
+
break;
|
|
724
|
+
default:
|
|
725
|
+
GGML_ABORT("Unsupported V type for flash attention shader");
|
|
726
|
+
}
|
|
727
|
+
variant += std::string("_v") + ggml_type_name(key.v_type);
|
|
728
|
+
|
|
729
|
+
switch (key.q_type) {
|
|
730
|
+
case GGML_TYPE_F32:
|
|
731
|
+
defines.push_back("Q_F32");
|
|
732
|
+
break;
|
|
733
|
+
case GGML_TYPE_F16:
|
|
734
|
+
defines.push_back("Q_F16");
|
|
735
|
+
break;
|
|
736
|
+
default:
|
|
737
|
+
GGML_ABORT("Unsupported Q type for flash attention shader");
|
|
738
|
+
}
|
|
739
|
+
variant += std::string("_q") + ggml_type_name(key.q_type);
|
|
740
|
+
|
|
741
|
+
switch (key.dst_type) {
|
|
742
|
+
case GGML_TYPE_F32:
|
|
743
|
+
defines.push_back("DST_F32");
|
|
744
|
+
break;
|
|
745
|
+
case GGML_TYPE_F16:
|
|
746
|
+
defines.push_back("DST_F16");
|
|
747
|
+
break;
|
|
748
|
+
default:
|
|
749
|
+
GGML_ABORT("Unsupported dst type for flash attention shader");
|
|
750
|
+
}
|
|
751
|
+
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
|
752
|
+
|
|
753
|
+
if (key.has_mask) {
|
|
754
|
+
defines.push_back("MASK");
|
|
755
|
+
variant += "_mask";
|
|
756
|
+
}
|
|
757
|
+
if (key.has_sinks) {
|
|
758
|
+
defines.push_back("SINKS");
|
|
759
|
+
variant += "_sinks";
|
|
760
|
+
}
|
|
761
|
+
if (key.uses_logit_softcap) {
|
|
762
|
+
defines.push_back("LOGIT_SOFTCAP");
|
|
763
|
+
variant += "_lgsc";
|
|
764
|
+
}
|
|
765
|
+
if (key.kv_direct) {
|
|
766
|
+
defines.push_back("KV_DIRECT");
|
|
767
|
+
variant += "_kvdirect";
|
|
768
|
+
}
|
|
769
|
+
if (key.kv_overlap) {
|
|
770
|
+
defines.push_back("KV_OVERLAP");
|
|
771
|
+
variant += "_kv_overlap";
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
|
775
|
+
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
|
776
|
+
|
|
777
|
+
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
|
778
|
+
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
|
779
|
+
|
|
780
|
+
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
|
781
|
+
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
|
782
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
783
|
+
|
|
784
|
+
if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
|
|
785
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
return defines;
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
|
792
|
+
uint32_t head_dim_v;
|
|
793
|
+
uint32_t wg_size;
|
|
794
|
+
ggml_type dst_type;
|
|
795
|
+
};
|
|
796
|
+
|
|
797
|
+
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
|
|
798
|
+
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
|
|
799
|
+
size_t seed = 0;
|
|
291
800
|
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
292
|
-
ggml_webgpu_hash_combine(seed, key.
|
|
293
|
-
ggml_webgpu_hash_combine(seed, key.
|
|
294
|
-
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
|
295
|
-
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
|
801
|
+
ggml_webgpu_hash_combine(seed, key.wg_size);
|
|
802
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
296
803
|
return seed;
|
|
297
804
|
}
|
|
298
805
|
};
|
|
299
806
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
uint32_t
|
|
807
|
+
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
|
|
808
|
+
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
|
|
809
|
+
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
struct ggml_webgpu_flash_attn_blk_pipeline_key {
|
|
813
|
+
uint32_t kv_tile;
|
|
814
|
+
|
|
815
|
+
bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; }
|
|
307
816
|
};
|
|
308
817
|
|
|
309
|
-
struct
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
818
|
+
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
|
|
819
|
+
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
|
|
820
|
+
size_t seed = 0;
|
|
821
|
+
ggml_webgpu_hash_combine(seed, key.kv_tile);
|
|
822
|
+
return seed;
|
|
823
|
+
}
|
|
313
824
|
};
|
|
314
825
|
|
|
315
|
-
//
|
|
826
|
+
// Note: this will slightly overestimate memory usage for vec path
|
|
827
|
+
// since row_max and exp_sum shmem are not needed.
|
|
316
828
|
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
317
829
|
uint32_t kv_tile,
|
|
318
830
|
uint32_t head_dim_qk,
|
|
@@ -322,47 +834,82 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
|
322
834
|
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
|
323
835
|
size_t f16_elems = 0;
|
|
324
836
|
size_t f32_elems = 0;
|
|
325
|
-
|
|
837
|
+
|
|
838
|
+
f32_elems += q_tile * head_dim_qk; // q_shmem
|
|
326
839
|
if (!kv_direct) {
|
|
327
|
-
|
|
840
|
+
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
|
328
841
|
}
|
|
329
|
-
|
|
842
|
+
f32_elems += q_tile * head_dim_v; // o_shmem
|
|
330
843
|
if (has_mask) {
|
|
331
|
-
|
|
844
|
+
f32_elems += q_tile * kv_tile; // mask_shmem
|
|
332
845
|
}
|
|
333
|
-
|
|
846
|
+
f32_elems += q_tile * kv_tile; // inter_shmem
|
|
334
847
|
f32_elems += q_tile; // row_max_shmem
|
|
335
848
|
f32_elems += q_tile; // exp_sum_shmem
|
|
336
849
|
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
337
850
|
}
|
|
338
851
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
852
|
+
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
|
|
853
|
+
uint32_t q_tile,
|
|
854
|
+
uint32_t kv_granularity,
|
|
855
|
+
uint32_t head_dim_qk,
|
|
856
|
+
uint32_t head_dim_v,
|
|
857
|
+
bool has_mask,
|
|
858
|
+
bool kv_direct) {
|
|
859
|
+
const size_t base_q_bytes =
|
|
860
|
+
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
861
|
+
if (limit_bytes <= base_q_bytes) {
|
|
862
|
+
return 0;
|
|
347
863
|
}
|
|
348
|
-
|
|
864
|
+
const size_t one_kv_bytes =
|
|
865
|
+
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
866
|
+
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
|
|
867
|
+
if (bytes_per_kv == 0) {
|
|
868
|
+
return 0;
|
|
869
|
+
}
|
|
870
|
+
const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
|
871
|
+
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
|
|
872
|
+
}
|
|
349
873
|
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
874
|
+
inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
|
|
875
|
+
uint32_t head_dim_qk,
|
|
876
|
+
uint32_t head_dim_v,
|
|
877
|
+
bool has_mask,
|
|
878
|
+
bool kv_direct) {
|
|
879
|
+
const uint32_t max_kv_tile =
|
|
880
|
+
ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
881
|
+
GGML_ASSERT(max_kv_tile > 0);
|
|
882
|
+
|
|
883
|
+
uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
|
|
884
|
+
if (kv_direct) {
|
|
885
|
+
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
|
886
|
+
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
|
887
|
+
kv_tile -= 1u;
|
|
888
|
+
}
|
|
356
889
|
}
|
|
357
|
-
|
|
890
|
+
|
|
891
|
+
return kv_tile;
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix,
|
|
895
|
+
uint32_t sg_mat_k,
|
|
896
|
+
uint32_t sg_mat_n,
|
|
897
|
+
const ggml_tensor * Q,
|
|
898
|
+
const ggml_tensor * V) {
|
|
899
|
+
return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
|
|
900
|
+
}
|
|
901
|
+
|
|
902
|
+
/** Matrix Multiplication **/
|
|
358
903
|
|
|
359
904
|
struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
|
360
905
|
ggml_type src0_type;
|
|
361
906
|
ggml_type src1_type;
|
|
362
907
|
int vectorized;
|
|
908
|
+
bool use_mmvq;
|
|
363
909
|
|
|
364
910
|
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
|
365
|
-
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized
|
|
911
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
|
912
|
+
use_mmvq == other.use_mmvq;
|
|
366
913
|
}
|
|
367
914
|
};
|
|
368
915
|
|
|
@@ -372,17 +919,31 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
|
|
372
919
|
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
373
920
|
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
374
921
|
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
922
|
+
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
|
375
923
|
return seed;
|
|
376
924
|
}
|
|
377
925
|
};
|
|
378
926
|
|
|
379
927
|
struct ggml_webgpu_mul_mat_vec_shader_decisions {
|
|
380
928
|
uint32_t wg_size;
|
|
381
|
-
uint32_t tile_k;
|
|
382
929
|
uint32_t outputs_per_wg;
|
|
383
930
|
uint32_t vec_size;
|
|
384
931
|
};
|
|
385
932
|
|
|
933
|
+
struct ggml_webgpu_quantize_q8_pipeline_key {
|
|
934
|
+
ggml_type src0_type;
|
|
935
|
+
|
|
936
|
+
bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
|
|
937
|
+
};
|
|
938
|
+
|
|
939
|
+
struct ggml_webgpu_quantize_q8_pipeline_key_hash {
|
|
940
|
+
size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
|
|
941
|
+
size_t seed = 0;
|
|
942
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
943
|
+
return seed;
|
|
944
|
+
}
|
|
945
|
+
};
|
|
946
|
+
|
|
386
947
|
struct ggml_webgpu_mul_mat_pipeline_key {
|
|
387
948
|
ggml_type src0_type;
|
|
388
949
|
ggml_type src1_type;
|
|
@@ -426,8 +987,152 @@ struct ggml_webgpu_mul_mat_shader_decisions {
|
|
|
426
987
|
uint32_t mul_mat_wg_size;
|
|
427
988
|
};
|
|
428
989
|
|
|
429
|
-
|
|
430
|
-
|
|
990
|
+
/** MUL_MAT_ID **/
|
|
991
|
+
|
|
992
|
+
struct ggml_webgpu_mul_mat_id_pipeline_key {
|
|
993
|
+
ggml_type src0_type;
|
|
994
|
+
ggml_type src1_type;
|
|
995
|
+
uint32_t n_experts;
|
|
996
|
+
int vectorized;
|
|
997
|
+
|
|
998
|
+
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
|
999
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
|
1000
|
+
vectorized == other.vectorized;
|
|
1001
|
+
}
|
|
1002
|
+
};
|
|
1003
|
+
|
|
1004
|
+
struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
|
1005
|
+
size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
|
|
1006
|
+
size_t seed = 0;
|
|
1007
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
1008
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
1009
|
+
ggml_webgpu_hash_combine(seed, key.n_experts);
|
|
1010
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
1011
|
+
return seed;
|
|
1012
|
+
}
|
|
1013
|
+
};
|
|
1014
|
+
|
|
1015
|
+
/** Cpy **/
|
|
1016
|
+
|
|
1017
|
+
struct ggml_webgpu_cpy_pipeline_key {
|
|
1018
|
+
ggml_type src_type;
|
|
1019
|
+
ggml_type dst_type;
|
|
1020
|
+
|
|
1021
|
+
bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
|
|
1022
|
+
return src_type == other.src_type && dst_type == other.dst_type;
|
|
1023
|
+
}
|
|
1024
|
+
};
|
|
1025
|
+
|
|
1026
|
+
struct ggml_webgpu_cpy_pipeline_key_hash {
|
|
1027
|
+
size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
|
|
1028
|
+
size_t seed = 0;
|
|
1029
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
1030
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
1031
|
+
return seed;
|
|
1032
|
+
}
|
|
1033
|
+
};
|
|
1034
|
+
|
|
1035
|
+
/** Glu **/
|
|
1036
|
+
|
|
1037
|
+
struct ggml_webgpu_glu_pipeline_key {
|
|
1038
|
+
ggml_glu_op glu_op;
|
|
1039
|
+
ggml_type type;
|
|
1040
|
+
bool split;
|
|
1041
|
+
|
|
1042
|
+
bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
|
|
1043
|
+
return glu_op == other.glu_op && type == other.type && split == other.split;
|
|
1044
|
+
}
|
|
1045
|
+
};
|
|
1046
|
+
|
|
1047
|
+
struct ggml_webgpu_glu_pipeline_key_hash {
|
|
1048
|
+
size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
|
|
1049
|
+
size_t seed = 0;
|
|
1050
|
+
ggml_webgpu_hash_combine(seed, key.glu_op);
|
|
1051
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
1052
|
+
ggml_webgpu_hash_combine(seed, key.split);
|
|
1053
|
+
return seed;
|
|
1054
|
+
}
|
|
1055
|
+
};
|
|
1056
|
+
|
|
1057
|
+
/** Rope **/
|
|
1058
|
+
|
|
1059
|
+
struct ggml_webgpu_rope_pipeline_key {
|
|
1060
|
+
ggml_type type;
|
|
1061
|
+
bool inplace;
|
|
1062
|
+
bool has_ff;
|
|
1063
|
+
|
|
1064
|
+
bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
|
|
1065
|
+
return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
|
|
1066
|
+
}
|
|
1067
|
+
};
|
|
1068
|
+
|
|
1069
|
+
struct ggml_webgpu_rope_pipeline_key_hash {
|
|
1070
|
+
size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
|
|
1071
|
+
size_t seed = 0;
|
|
1072
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
1073
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
1074
|
+
ggml_webgpu_hash_combine(seed, key.has_ff);
|
|
1075
|
+
return seed;
|
|
1076
|
+
}
|
|
1077
|
+
};
|
|
1078
|
+
|
|
1079
|
+
/** SoftMax **/
|
|
1080
|
+
|
|
1081
|
+
struct ggml_webgpu_soft_max_pipeline_key {
|
|
1082
|
+
ggml_type mask_type;
|
|
1083
|
+
bool has_mask;
|
|
1084
|
+
bool has_sink;
|
|
1085
|
+
bool inplace;
|
|
1086
|
+
|
|
1087
|
+
bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
|
|
1088
|
+
return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
|
|
1089
|
+
inplace == other.inplace;
|
|
1090
|
+
}
|
|
1091
|
+
};
|
|
1092
|
+
|
|
1093
|
+
struct ggml_webgpu_soft_max_pipeline_key_hash {
|
|
1094
|
+
size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
|
|
1095
|
+
size_t seed = 0;
|
|
1096
|
+
ggml_webgpu_hash_combine(seed, key.mask_type);
|
|
1097
|
+
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
1098
|
+
ggml_webgpu_hash_combine(seed, key.has_sink);
|
|
1099
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
1100
|
+
return seed;
|
|
1101
|
+
}
|
|
1102
|
+
};
|
|
1103
|
+
|
|
1104
|
+
/** MMVQ **/
|
|
1105
|
+
|
|
1106
|
+
inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
|
1107
|
+
const ggml_tensor * src1,
|
|
1108
|
+
bool supports_dot_product,
|
|
1109
|
+
const std::string & vendor) {
|
|
1110
|
+
if (src1->ne[1] == 1) {
|
|
1111
|
+
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
|
1112
|
+
if (supports_dp4a && supports_dot_product) {
|
|
1113
|
+
switch (src1->type) {
|
|
1114
|
+
case GGML_TYPE_F32:
|
|
1115
|
+
switch (src0->type) {
|
|
1116
|
+
case GGML_TYPE_Q4_0:
|
|
1117
|
+
case GGML_TYPE_Q4_1:
|
|
1118
|
+
case GGML_TYPE_Q8_0:
|
|
1119
|
+
case GGML_TYPE_Q2_K:
|
|
1120
|
+
case GGML_TYPE_Q4_K:
|
|
1121
|
+
return src0->ne[0] % 4 == 0;
|
|
1122
|
+
default:
|
|
1123
|
+
break;
|
|
1124
|
+
}
|
|
1125
|
+
break;
|
|
1126
|
+
default:
|
|
1127
|
+
break;
|
|
1128
|
+
}
|
|
1129
|
+
}
|
|
1130
|
+
}
|
|
1131
|
+
return false;
|
|
1132
|
+
}
|
|
1133
|
+
|
|
1134
|
+
class ggml_webgpu_shader_lib {
|
|
1135
|
+
wgpu::Device device;
|
|
431
1136
|
pre_wgsl::Preprocessor preprocessor;
|
|
432
1137
|
|
|
433
1138
|
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
|
@@ -435,33 +1140,81 @@ class ggml_webgpu_shader_lib {
|
|
|
435
1140
|
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
|
|
436
1141
|
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
|
|
437
1142
|
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
|
1143
|
+
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
|
1144
|
+
row_norm_pipelines; // op/inplace
|
|
1145
|
+
|
|
438
1146
|
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
|
439
|
-
get_rows_pipelines;
|
|
1147
|
+
get_rows_pipelines; // src_type, vectorized
|
|
440
1148
|
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
|
441
|
-
unary_pipelines;
|
|
1149
|
+
unary_pipelines; // type/op/inplace
|
|
442
1150
|
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
|
443
|
-
scale_pipelines;
|
|
1151
|
+
scale_pipelines; // inplace
|
|
1152
|
+
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
|
1153
|
+
solve_tri_pipelines; // type
|
|
1154
|
+
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
|
1155
|
+
ssm_conv_pipelines; // type/vectorized
|
|
1156
|
+
std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
|
|
1157
|
+
ssm_scan_pipelines; // type/d_state
|
|
1158
|
+
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
|
1159
|
+
webgpu_pipeline,
|
|
1160
|
+
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
|
1161
|
+
gated_delta_net_pipelines; // type/S_v/kda
|
|
444
1162
|
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
|
445
|
-
pad_pipelines;
|
|
1163
|
+
pad_pipelines; // circular/non-circular
|
|
446
1164
|
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
|
447
|
-
binary_pipelines;
|
|
1165
|
+
binary_pipelines; // type/op/inplace/overlap/src_overlap
|
|
1166
|
+
std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
|
|
1167
|
+
add_id_pipelines; // inplace
|
|
448
1168
|
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
|
449
|
-
concat_pipelines;
|
|
1169
|
+
concat_pipelines; // type
|
|
450
1170
|
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
|
451
|
-
repeat_pipelines;
|
|
1171
|
+
repeat_pipelines; // type
|
|
1172
|
+
std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
|
|
1173
|
+
webgpu_pipeline,
|
|
1174
|
+
ggml_webgpu_flash_attn_vec_pipeline_key_hash>
|
|
1175
|
+
flash_attn_vec_pipelines;
|
|
452
1176
|
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
|
453
1177
|
flash_attn_pipelines;
|
|
454
|
-
std::unordered_map<
|
|
1178
|
+
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
|
455
1179
|
webgpu_pipeline,
|
|
456
|
-
|
|
457
|
-
|
|
1180
|
+
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
|
1181
|
+
flash_attn_vec_reduce_pipelines;
|
|
1182
|
+
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
|
|
1183
|
+
webgpu_pipeline,
|
|
1184
|
+
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
|
1185
|
+
flash_attn_blk_pipelines;
|
|
458
1186
|
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
|
459
|
-
mul_mat_vec_pipelines;
|
|
1187
|
+
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
|
460
1188
|
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
|
461
|
-
mul_mat_fast_pipelines;
|
|
1189
|
+
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
|
1190
|
+
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
|
|
1191
|
+
quantize_q8_pipelines;
|
|
1192
|
+
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
|
1193
|
+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
|
1194
|
+
mul_mat_id_pipelines; // src0_type/src1_type
|
|
1195
|
+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
|
1196
|
+
mul_mat_id_vec_pipelines; // src0_type/src1_type
|
|
462
1197
|
|
|
463
1198
|
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
|
464
1199
|
set_rows_pipelines;
|
|
1200
|
+
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
|
|
1201
|
+
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
|
|
1202
|
+
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
|
|
1203
|
+
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
|
|
1204
|
+
rope_pipelines;
|
|
1205
|
+
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
|
|
1206
|
+
soft_max_pipelines;
|
|
1207
|
+
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
|
1208
|
+
conv2d_pipelines;
|
|
1209
|
+
std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
|
|
1210
|
+
im2col_pipelines;
|
|
1211
|
+
|
|
1212
|
+
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
|
|
1213
|
+
webgpu_pipeline,
|
|
1214
|
+
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
|
|
1215
|
+
rms_norm_mul_pipelines;
|
|
1216
|
+
std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
|
|
1217
|
+
upscale_pipelines;
|
|
465
1218
|
|
|
466
1219
|
public:
|
|
467
1220
|
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
|
@@ -479,6 +1232,70 @@ class ggml_webgpu_shader_lib {
|
|
|
479
1232
|
return sum_rows_pipelines[1];
|
|
480
1233
|
}
|
|
481
1234
|
|
|
1235
|
+
// Returns the row-norm pipeline matching `context`, compiling and caching it
// on first use.
//
// The cache key is (dst op, src0 type, dst type, inplace), where `inplace` is
// true when ggml_webgpu_tensor_equal reports src0 and dst as equal. Supported
// ops are RMS_NORM, NORM and L2_NORM; any other op aborts. F32/F16 source and
// destination types add their own define and variant suffix; other types add
// nothing. The workgroup size is the smaller of context.max_wg_size and 128.
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_row_norm_pipeline_key key = {};
    key.op                                = context.dst->op;
    key.src_type                          = context.src0->type;
    key.dst_type                          = context.dst->type;
    key.inplace                           = ggml_webgpu_tensor_equal(context.src0, context.dst);

    const auto cached = row_norm_pipelines.find(key);
    if (cached != row_norm_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant;

    // Base define / label chosen by the normalisation op.
    switch (key.op) {
        case GGML_OP_RMS_NORM:
            defines.emplace_back("RMS_NORM");
            variant = "rms_norm";
            break;
        case GGML_OP_NORM:
            defines.emplace_back("NORM");
            variant = "norm";
            break;
        case GGML_OP_L2_NORM:
            defines.emplace_back("L2_NORM");
            variant = "l2_norm";
            break;
        default:
            GGML_ABORT("Unsupported op for row_norm shader");
    }

    if (key.inplace) {
        defines.emplace_back("INPLACE");
        variant += "_inplace";
    }

    switch (key.src_type) {
        case GGML_TYPE_F32:
            defines.emplace_back("SRC_F32");
            variant += "_src_f32";
            break;
        case GGML_TYPE_F16:
            defines.emplace_back("SRC_F16");
            variant += "_src_f16";
            break;
        default:
            break;
    }

    switch (key.dst_type) {
        case GGML_TYPE_F32:
            defines.emplace_back("DST_F32");
            variant += "_dst_f32";
            break;
        case GGML_TYPE_F16:
            defines.emplace_back("DST_F16");
            variant += "_dst_f16";
            break;
        default:
            break;
    }

    const uint32_t row_norm_wg_size = 128u;
    const uint32_t wg_size          = std::min(context.max_wg_size, row_norm_wg_size);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    auto processed = preprocessor.preprocess(wgsl_row_norm, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    row_norm_pipelines[key]  = pipeline;
    return row_norm_pipelines[key];
}
|
|
1298
|
+
|
|
482
1299
|
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
483
1300
|
bool vec4 = context.src0->ne[0] % 4 == 0;
|
|
484
1301
|
|
|
@@ -500,9 +1317,13 @@ class ggml_webgpu_shader_lib {
|
|
|
500
1317
|
}
|
|
501
1318
|
|
|
502
1319
|
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
1320
|
+
const bool quantized = ggml_is_quantized(context.dst->type);
|
|
1321
|
+
ggml_webgpu_set_rows_pipeline_key key = {};
|
|
1322
|
+
key.dst_type = context.dst->type;
|
|
1323
|
+
key.vec4 =
|
|
1324
|
+
(context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0;
|
|
1325
|
+
key.i64_idx = context.src1->type == GGML_TYPE_I64;
|
|
1326
|
+
key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0);
|
|
506
1327
|
|
|
507
1328
|
auto it = set_rows_pipelines.find(key);
|
|
508
1329
|
if (it != set_rows_pipelines.end()) {
|
|
@@ -521,6 +1342,14 @@ class ggml_webgpu_shader_lib {
|
|
|
521
1342
|
defines.push_back("DST_F16");
|
|
522
1343
|
variant += "_dstf16";
|
|
523
1344
|
break;
|
|
1345
|
+
case GGML_TYPE_Q8_0:
|
|
1346
|
+
defines.push_back("DST_Q8_0");
|
|
1347
|
+
variant += "_dstq8_0";
|
|
1348
|
+
break;
|
|
1349
|
+
case GGML_TYPE_Q4_0:
|
|
1350
|
+
defines.push_back("DST_Q4_0");
|
|
1351
|
+
variant += "_dstq4_0";
|
|
1352
|
+
break;
|
|
524
1353
|
default:
|
|
525
1354
|
GGML_ABORT("Unsupported dst type for set_rows shader");
|
|
526
1355
|
}
|
|
@@ -533,19 +1362,68 @@ class ggml_webgpu_shader_lib {
|
|
|
533
1362
|
defines.push_back("I64_IDX");
|
|
534
1363
|
variant += "_i64idx";
|
|
535
1364
|
}
|
|
1365
|
+
if (key.pair_blocks) {
|
|
1366
|
+
defines.push_back("PAIR_BLOCKS");
|
|
1367
|
+
variant += "_pair_blocks";
|
|
1368
|
+
}
|
|
536
1369
|
|
|
537
1370
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
538
1371
|
|
|
539
|
-
auto
|
|
540
|
-
auto
|
|
1372
|
+
const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
|
|
1373
|
+
auto processed = preprocessor.preprocess(shader_source, defines);
|
|
1374
|
+
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
|
|
541
1375
|
decisions->vec4 = key.vec4;
|
|
542
1376
|
decisions->i64_idx = key.i64_idx;
|
|
1377
|
+
decisions->pair_blocks = key.pair_blocks;
|
|
543
1378
|
decisions->wg_size = context.max_wg_size;
|
|
544
1379
|
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
545
1380
|
set_rows_pipelines[key].context = decisions;
|
|
546
1381
|
return set_rows_pipelines[key];
|
|
547
1382
|
}
|
|
548
1383
|
|
|
1384
|
+
// Returns the "set" pipeline matching `context`, compiling and caching it on
// first use.
//
// The cache key is (dst type, inplace), where `inplace` is true when
// ggml_webgpu_tensor_equal reports src0 and dst as equal. Only F32 and I32
// destination types are handled; anything else aborts. The workgroup size is
// context.max_wg_size.
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_set_pipeline_key key = {};
    key.type                         = context.dst->type;
    key.inplace                      = ggml_webgpu_tensor_equal(context.src0, context.dst);

    const auto cached = set_pipelines.find(key);
    if (cached != set_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "set";

    switch (key.type) {
        case GGML_TYPE_F32:
            defines.emplace_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_I32:
            defines.emplace_back("TYPE_I32");
            variant += "_i32";
            break;
        default:
            GGML_ABORT("Unsupported type for set shader");
    }

    if (key.inplace) {
        defines.emplace_back("INPLACE");
        variant += "_inplace";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_set, defines);

    // Record the choices made here on the pipeline's context.
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline compiled = ggml_webgpu_create_pipeline(device, processed, variant);
    compiled.context         = decisions;
    set_pipelines[key]       = compiled;
    return set_pipelines[key];
}
|
|
1426
|
+
|
|
549
1427
|
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
550
1428
|
auto it = cumsum_pipelines.find(1);
|
|
551
1429
|
if (it != cumsum_pipelines.end()) {
|
|
@@ -614,10 +1492,9 @@ class ggml_webgpu_shader_lib {
|
|
|
614
1492
|
|
|
615
1493
|
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
616
1494
|
const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
|
|
617
|
-
ggml_webgpu_get_rows_pipeline_key key = {
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
};
|
|
1495
|
+
ggml_webgpu_get_rows_pipeline_key key = {};
|
|
1496
|
+
key.src_type = context.src0->type;
|
|
1497
|
+
key.vectorized = (int) vectorized;
|
|
621
1498
|
|
|
622
1499
|
auto it = get_rows_pipelines.find(key);
|
|
623
1500
|
if (it != get_rows_pipelines.end()) {
|
|
@@ -632,6 +1509,7 @@ class ggml_webgpu_shader_lib {
|
|
|
632
1509
|
|
|
633
1510
|
switch (key.src_type) {
|
|
634
1511
|
case GGML_TYPE_F32:
|
|
1512
|
+
defines.push_back("FLOAT_PARALLEL");
|
|
635
1513
|
if (key.vectorized) {
|
|
636
1514
|
defines.push_back("F32_VEC");
|
|
637
1515
|
defines.push_back("SRC_TYPE=vec4<f32>");
|
|
@@ -646,6 +1524,7 @@ class ggml_webgpu_shader_lib {
|
|
|
646
1524
|
variant += "_f32";
|
|
647
1525
|
break;
|
|
648
1526
|
case GGML_TYPE_F16:
|
|
1527
|
+
defines.push_back("FLOAT_PARALLEL");
|
|
649
1528
|
defines.push_back("F16");
|
|
650
1529
|
defines.push_back("SRC_TYPE=f16");
|
|
651
1530
|
defines.push_back("DST_TYPE=f32");
|
|
@@ -653,6 +1532,7 @@ class ggml_webgpu_shader_lib {
|
|
|
653
1532
|
variant += "_f16";
|
|
654
1533
|
break;
|
|
655
1534
|
case GGML_TYPE_I32:
|
|
1535
|
+
defines.push_back("FLOAT_PARALLEL");
|
|
656
1536
|
defines.push_back("I32");
|
|
657
1537
|
defines.push_back("SRC_TYPE=i32");
|
|
658
1538
|
defines.push_back("DST_TYPE=i32");
|
|
@@ -664,21 +1544,50 @@ class ggml_webgpu_shader_lib {
|
|
|
664
1544
|
std::string type_upper = type_str;
|
|
665
1545
|
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
666
1546
|
|
|
1547
|
+
switch (key.src_type) {
|
|
1548
|
+
case GGML_TYPE_Q1_0:
|
|
1549
|
+
case GGML_TYPE_Q4_0:
|
|
1550
|
+
case GGML_TYPE_Q5_0:
|
|
1551
|
+
case GGML_TYPE_Q8_0:
|
|
1552
|
+
case GGML_TYPE_Q3_K:
|
|
1553
|
+
case GGML_TYPE_Q6_K:
|
|
1554
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1555
|
+
case GGML_TYPE_IQ2_XS:
|
|
1556
|
+
case GGML_TYPE_IQ2_S:
|
|
1557
|
+
case GGML_TYPE_IQ3_XXS:
|
|
1558
|
+
case GGML_TYPE_IQ3_S:
|
|
1559
|
+
case GGML_TYPE_IQ1_S:
|
|
1560
|
+
case GGML_TYPE_IQ4_NL:
|
|
1561
|
+
case GGML_TYPE_MXFP4:
|
|
1562
|
+
{
|
|
1563
|
+
// Quantized types using u32 buffers for portability.
|
|
1564
|
+
defines.push_back("SRC_TYPE=u32");
|
|
1565
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
1566
|
+
break;
|
|
1567
|
+
}
|
|
1568
|
+
default:
|
|
1569
|
+
{
|
|
1570
|
+
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
|
1571
|
+
}
|
|
1572
|
+
}
|
|
1573
|
+
|
|
667
1574
|
defines.push_back("BYTE_HELPERS");
|
|
668
1575
|
defines.push_back(type_upper + "_T");
|
|
669
1576
|
defines.push_back(type_upper);
|
|
670
1577
|
defines.push_back(type_upper + "_SCALE_MIN");
|
|
671
1578
|
defines.push_back(type_upper + "_TABLES");
|
|
672
1579
|
defines.push_back(type_upper + "_GRID");
|
|
1580
|
+
defines.push_back(type_upper + "_LUT");
|
|
673
1581
|
|
|
674
1582
|
variant += "_";
|
|
675
1583
|
variant += type_str;
|
|
676
1584
|
|
|
677
|
-
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
|
678
1585
|
defines.push_back("DST_TYPE=f32");
|
|
679
1586
|
|
|
680
|
-
if (
|
|
681
|
-
|
|
1587
|
+
if (key.src_type == GGML_TYPE_Q1_0) {
|
|
1588
|
+
defines.push_back("BLOCK_SIZE=128u");
|
|
1589
|
+
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
|
1590
|
+
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
|
|
682
1591
|
defines.push_back("BLOCK_SIZE=32u");
|
|
683
1592
|
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
|
684
1593
|
defines.push_back("BLOCK_SIZE=256u");
|
|
@@ -705,7 +1614,8 @@ class ggml_webgpu_shader_lib {
|
|
|
705
1614
|
}
|
|
706
1615
|
|
|
707
1616
|
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
708
|
-
ggml_webgpu_scale_pipeline_key key = {
|
|
1617
|
+
ggml_webgpu_scale_pipeline_key key = {};
|
|
1618
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
709
1619
|
|
|
710
1620
|
auto it = scale_pipelines.find(key);
|
|
711
1621
|
if (it != scale_pipelines.end()) {
|
|
@@ -725,14 +1635,189 @@ class ggml_webgpu_shader_lib {
|
|
|
725
1635
|
auto processed = preprocessor.preprocess(wgsl_scale, defines);
|
|
726
1636
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
727
1637
|
decisions->wg_size = context.max_wg_size;
|
|
1638
|
+
decisions->inplace = key.inplace;
|
|
728
1639
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
729
1640
|
pipeline.context = decisions;
|
|
730
1641
|
scale_pipelines[key] = pipeline;
|
|
731
1642
|
return scale_pipelines[key];
|
|
732
1643
|
}
|
|
733
1644
|
|
|
1645
|
+
// Returns the solve_tri pipeline matching `context`, compiling and caching it
// on first use.
//
// The cache key is (dst type, n = src0->ne[0], k = src1->ne[0]). Only F32 is
// handled; any other type aborts. Shader parameters derived here:
//   WG_SIZE = min(n, context.max_wg_size)
//   K_TILE  = WG_SIZE
//   BATCH_N = context.wg_mem_limit_bytes / ((n + WG_SIZE) * GGML_WEBGPU_F32_SIZE_BYTES)
//
// NOTE(review): key.k is part of the cache key but feeds no define in this
// function, so distinct k values appear to compile identical shaders; confirm
// whether that is intended.
// NOTE(review): BATCH_N divides by a value that is zero when n is zero; this
// presumably never happens for tensors that reach this point. Confirm against
// callers.
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_solve_tri_pipeline_key key = {};
    key.type                               = context.dst->type;
    key.n                                  = (int) context.src0->ne[0];
    key.k                                  = (int) context.src1->ne[0];

    const auto cached = solve_tri_pipelines.find(key);
    if (cached != solve_tri_pipelines.end()) {
        return cached->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for solve_tri shader");
    }

    std::string variant = "solve_tri";
    variant += "_f32";

    const uint32_t n             = (uint32_t) key.n;
    const uint32_t wg_size       = std::min(n, context.max_wg_size);
    const uint32_t k_tile        = wg_size;
    const uint32_t bytes_per_row = (n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
    const uint32_t batch_n       = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);

    std::vector<std::string> defines;
    defines.push_back(std::string("N=") + std::to_string(key.n));
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
    defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));

    auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline compiled = ggml_webgpu_create_pipeline(device, processed, variant);
    compiled.context         = decisions;
    solve_tri_pipelines[key] = compiled;
    return solve_tri_pipelines[key];
}
|
|
1685
|
+
|
|
1686
|
+
// Returns the ssm_conv pipeline matching `context`, compiling and caching it
// on first use.
//
// The cache key is (dst type, vectorized), where `vectorized` is true exactly
// when src1->ne[0] == 4. Only F32 is handled; any other type aborts.
// BLOCK_SIZE (32) and TOKENS_PER_WG (8) are fixed here and also recorded on
// the pipeline's context.
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_ssm_conv_pipeline_key key = {};
    key.type                              = context.dst->type;
    key.vectorized                        = context.src1->ne[0] == 4;

    const auto cached = ssm_conv_pipelines.find(key);
    if (cached != ssm_conv_pipelines.end()) {
        return cached->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for ssm_conv shader");
    }

    std::vector<std::string> defines;
    std::string              variant = "ssm_conv";
    variant += "_f32";

    if (key.vectorized) {
        defines.emplace_back("VECTORIZED");
        variant += "_vec4";
    }

    constexpr uint32_t block_size    = 32u;
    constexpr uint32_t tokens_per_wg = 8u;

    defines.push_back(std::string("BLOCK_SIZE=") + std::to_string(block_size) + "u");
    defines.push_back(std::string("TOKENS_PER_WG=") + std::to_string(tokens_per_wg) + "u");

    auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);

    auto decisions           = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
    decisions->block_size    = block_size;
    decisions->tokens_per_wg = tokens_per_wg;

    webgpu_pipeline compiled = ggml_webgpu_create_pipeline(device, processed, variant);
    compiled.context         = decisions;
    ssm_conv_pipelines[key]  = compiled;
    return ssm_conv_pipelines[key];
}
|
|
1727
|
+
|
|
1728
|
+
// Returns the ssm_scan pipeline matching `context`, compiling and caching it
// on first use.
//
// The cache key is (dst type, d_state = src0->ne[0], xbc_overlap), where
// `xbc_overlap` is true when src1 overlaps both src4 and src5 according to
// ggml_webgpu_tensor_overlap. Only F32 is handled; any other type aborts.
//
// The workgroup size is taken directly from d_state; it is not clamped to
// context.max_wg_size in this function. TOKENS_PER_TILE is fixed at 4.
// Subgroup reduction is used when context.supports_subgroups is set.
//
// Every option that changes the generated shader is reflected in the variant
// label, including xbc_overlap, so two differently compiled pipelines never
// share a label.
webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_ssm_scan_pipeline_key key = {};
    key.type                              = context.dst->type;
    key.d_state                           = (int) context.src0->ne[0];
    key.xbc_overlap                       = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
                      ggml_webgpu_tensor_overlap(context.src1, context.src5);

    auto it = ssm_scan_pipelines.find(key);
    if (it != ssm_scan_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "ssm_scan";

    switch (key.type) {
        case GGML_TYPE_F32:
            variant += "_f32";
            break;
        default:
            GGML_ABORT("Unsupported type for ssm_scan shader");
    }

    const uint32_t wg_size = (uint32_t) key.d_state;

    constexpr uint32_t tokens_per_tile = 4u;

    defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
    defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");

    if (context.supports_subgroups) {
        defines.push_back("USE_SUBGROUP_REDUCTION");
        variant += "_sg_reduce";
    } else {
        variant += "_wg_reduce";
    }

    if (key.xbc_overlap) {
        defines.push_back("XBC_OVERLAP");
        // Keep the label in step with the define, as the other flags do.
        variant += "_xbc_overlap";
    }

    variant += "_d" + std::to_string(key.d_state);

    auto processed             = preprocessor.preprocess(wgsl_ssm_scan, defines);
    auto decisions             = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
    decisions->wg_size         = wg_size;
    decisions->tokens_per_tile = tokens_per_tile;
    decisions->xbc_overlap     = key.xbc_overlap;
    webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context           = decisions;
    ssm_scan_pipelines[key]    = pipeline;
    return ssm_scan_pipelines[key];
}
|
|
1781
|
+
|
|
1782
|
+
// Returns the gated_delta_net pipeline matching `context`, compiling and
// caching it on first use.
//
// The cache key is (dst type, s_v = src2->ne[0], kda), where `kda` is true
// when src3->ne[0] equals src2->ne[0]. Only F32 is handled; any other type
// aborts. Both S_V and WG_SIZE are set to s_v. No decisions object is attached
// to the pipeline's context here.
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_gated_delta_net_pipeline_key key = {};
    key.type                                     = context.dst->type;
    key.s_v                                      = (int) context.src2->ne[0];
    key.kda                                      = context.src3->ne[0] == context.src2->ne[0];

    const auto cached = gated_delta_net_pipelines.find(key);
    if (cached != gated_delta_net_pipelines.end()) {
        return cached->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for gated_delta_net shader");
    }

    std::vector<std::string> defines;
    std::string              variant = "gated_delta_net";
    variant += "_f32";

    if (key.kda) {
        defines.emplace_back("KDA");
        variant += "_kda";
    }

    const std::string s_v_literal = std::to_string(key.s_v) + "u";
    defines.push_back("S_V=" + s_v_literal);
    defines.push_back("WG_SIZE=" + s_v_literal);

    auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);

    webgpu_pipeline compiled       = ggml_webgpu_create_pipeline(device, processed, variant);
    gated_delta_net_pipelines[key] = compiled;
    return gated_delta_net_pipelines[key];
}
|
|
1817
|
+
|
|
734
1818
|
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
735
|
-
ggml_webgpu_pad_pipeline_key key = {
|
|
1819
|
+
ggml_webgpu_pad_pipeline_key key = {};
|
|
1820
|
+
key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;
|
|
736
1821
|
|
|
737
1822
|
auto it = pad_pipelines.find(key);
|
|
738
1823
|
if (it != pad_pipelines.end()) {
|
|
@@ -758,16 +1843,54 @@ class ggml_webgpu_shader_lib {
|
|
|
758
1843
|
return pad_pipelines[key];
|
|
759
1844
|
}
|
|
760
1845
|
|
|
1846
|
+
// Returns the quantize_q8 pipeline matching `context`, compiling and caching
// it on first use.
//
// The cache key is the src0 type alone. The shader is specialised with:
//   - SRC1_INNER_TYPE=f32 and WG_SIZE=WEBGPU_MUL_MAT_VEC_WG_SIZE
//   - MUL_ACC_<TYPE>, where <TYPE> is the upper-cased ggml type name of src0
//   - Q8_1_T
//   - USE_SUBGROUP_REDUCTION or USE_WORKGROUP_REDUCTION, depending on
//     context.supports_subgroups
webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_quantize_q8_pipeline_key key = {};
    key.src0_type                            = context.src0->type;

    const auto cached = quantize_q8_pipelines.find(key);
    if (cached != quantize_q8_pipelines.end()) {
        return cached->second;
    }

    const char *             source = wgsl_quantize_q8;
    std::vector<std::string> defines;
    std::string              variant = "quantize_q8";

    const uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;

    defines.emplace_back("SRC1_INNER_TYPE=f32");
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    // The type name drives both the variant label (as-is) and the MUL_ACC_*
    // define (upper-cased).
    const struct ggml_type_traits * traits    = ggml_get_type_traits(context.src0->type);
    const std::string               type_name = traits->type_name;
    variant += "_" + type_name;

    std::string type_upper = type_name;
    for (char & ch : type_upper) {
        ch = ::toupper(ch);
    }

    defines.push_back("MUL_ACC_" + type_upper);
    defines.emplace_back("Q8_1_T");

    if (context.supports_subgroups) {
        defines.emplace_back("USE_SUBGROUP_REDUCTION");
        variant += "_sg_reduce";
    } else {
        defines.emplace_back("USE_WORKGROUP_REDUCTION");
        variant += "_wg_reduce";
    }

    auto processed = preprocessor.preprocess(source, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline compiled   = ggml_webgpu_create_pipeline(device, processed, variant);
    compiled.context           = decisions;
    quantize_q8_pipelines[key] = compiled;
    return quantize_q8_pipelines[key];
}
|
|
1883
|
+
|
|
761
1884
|
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
762
|
-
ggml_webgpu_mul_mat_vec_pipeline_key key = {
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
1885
|
+
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
|
1886
|
+
key.src0_type = context.src0->type;
|
|
1887
|
+
key.src1_type = context.src1->type;
|
|
1888
|
+
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
|
1889
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
1890
|
+
1 :
|
|
1891
|
+
0;
|
|
1892
|
+
key.use_mmvq =
|
|
1893
|
+
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
|
771
1894
|
|
|
772
1895
|
auto it = mul_mat_vec_pipelines.find(key);
|
|
773
1896
|
if (it != mul_mat_vec_pipelines.end()) {
|
|
@@ -775,7 +1898,8 @@ class ggml_webgpu_shader_lib {
|
|
|
775
1898
|
}
|
|
776
1899
|
|
|
777
1900
|
std::vector<std::string> defines;
|
|
778
|
-
std::string variant
|
|
1901
|
+
std::string variant = "mul_mat_vec";
|
|
1902
|
+
const char * shader_src = wgsl_mul_mat_vec;
|
|
779
1903
|
|
|
780
1904
|
// src0 type (matrix row)
|
|
781
1905
|
switch (context.src0->type) {
|
|
@@ -800,9 +1924,42 @@ class ggml_webgpu_shader_lib {
|
|
|
800
1924
|
|
|
801
1925
|
defines.push_back("BYTE_HELPERS");
|
|
802
1926
|
defines.push_back("MUL_ACC_" + type_upper);
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
1927
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
1928
|
+
defines.push_back("SRC0_INNER_TYPE=u32");
|
|
1929
|
+
switch (context.src0->type) {
|
|
1930
|
+
case GGML_TYPE_Q8_0:
|
|
1931
|
+
case GGML_TYPE_Q4_0:
|
|
1932
|
+
case GGML_TYPE_Q4_1:
|
|
1933
|
+
if (key.use_mmvq) {
|
|
1934
|
+
defines.push_back("LEGACY_QUANTS");
|
|
1935
|
+
}
|
|
1936
|
+
break;
|
|
1937
|
+
case GGML_TYPE_Q2_K:
|
|
1938
|
+
case GGML_TYPE_Q4_K:
|
|
1939
|
+
if (key.use_mmvq) {
|
|
1940
|
+
defines.push_back("K_QUANTS");
|
|
1941
|
+
}
|
|
1942
|
+
break;
|
|
1943
|
+
case GGML_TYPE_IQ1_S:
|
|
1944
|
+
case GGML_TYPE_IQ1_M:
|
|
1945
|
+
case GGML_TYPE_IQ2_S:
|
|
1946
|
+
case GGML_TYPE_IQ3_S:
|
|
1947
|
+
case GGML_TYPE_IQ4_NL:
|
|
1948
|
+
case GGML_TYPE_IQ4_XS:
|
|
1949
|
+
defines.push_back(type_upper + "_GRID");
|
|
1950
|
+
break;
|
|
1951
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1952
|
+
case GGML_TYPE_IQ2_XS:
|
|
1953
|
+
case GGML_TYPE_IQ3_XXS:
|
|
1954
|
+
defines.push_back(type_upper + "_GRID");
|
|
1955
|
+
defines.push_back(type_upper + "_TABLES");
|
|
1956
|
+
break;
|
|
1957
|
+
case GGML_TYPE_MXFP4:
|
|
1958
|
+
defines.push_back(type_upper + "_LUT");
|
|
1959
|
+
break;
|
|
1960
|
+
default:
|
|
1961
|
+
break;
|
|
1962
|
+
}
|
|
806
1963
|
break;
|
|
807
1964
|
}
|
|
808
1965
|
}
|
|
@@ -825,25 +1982,32 @@ class ggml_webgpu_shader_lib {
|
|
|
825
1982
|
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
826
1983
|
|
|
827
1984
|
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
|
828
|
-
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
|
829
1985
|
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
|
830
1986
|
|
|
831
|
-
if (key.src0_type
|
|
832
|
-
|
|
1987
|
+
if (key.src0_type == GGML_TYPE_Q1_0) {
|
|
1988
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
|
1989
|
+
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
|
|
833
1990
|
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
|
834
1991
|
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
|
835
|
-
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
|
836
1992
|
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
|
837
1993
|
}
|
|
838
1994
|
|
|
1995
|
+
if (key.use_mmvq) {
|
|
1996
|
+
defines.push_back("MMVQ");
|
|
1997
|
+
defines.push_back("Q8_1_T");
|
|
1998
|
+
}
|
|
1999
|
+
|
|
839
2000
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
840
|
-
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
|
841
2001
|
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
|
2002
|
+
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
|
2003
|
+
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
|
|
2004
|
+
if (key.vectorized) {
|
|
2005
|
+
variant += "_vectorized";
|
|
2006
|
+
}
|
|
842
2007
|
|
|
843
|
-
auto processed = preprocessor.preprocess(
|
|
2008
|
+
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
844
2009
|
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
|
845
2010
|
decisions->wg_size = wg_size;
|
|
846
|
-
decisions->tile_k = tile_k;
|
|
847
2011
|
decisions->outputs_per_wg = outputs_per_wg;
|
|
848
2012
|
decisions->vec_size = key.vectorized ? 4 : 1;
|
|
849
2013
|
|
|
@@ -854,15 +2018,14 @@ class ggml_webgpu_shader_lib {
|
|
|
854
2018
|
}
|
|
855
2019
|
|
|
856
2020
|
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
857
|
-
ggml_webgpu_mul_mat_pipeline_key key = {
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
};
|
|
2021
|
+
ggml_webgpu_mul_mat_pipeline_key key = {};
|
|
2022
|
+
key.src0_type = context.src0->type;
|
|
2023
|
+
key.src1_type = context.src1->type;
|
|
2024
|
+
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
|
2025
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
2026
|
+
1 :
|
|
2027
|
+
0;
|
|
2028
|
+
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
|
866
2029
|
|
|
867
2030
|
auto it = mul_mat_fast_pipelines.find(key);
|
|
868
2031
|
if (it != mul_mat_fast_pipelines.end()) {
|
|
@@ -915,9 +2078,30 @@ class ggml_webgpu_shader_lib {
|
|
|
915
2078
|
defines.push_back("MUL_ACC_" + type_upper);
|
|
916
2079
|
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
|
917
2080
|
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
2081
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
2082
|
+
defines.push_back("SRC0_INNER_TYPE=u32");
|
|
2083
|
+
|
|
2084
|
+
switch (context.src0->type) {
|
|
2085
|
+
case GGML_TYPE_IQ1_S:
|
|
2086
|
+
case GGML_TYPE_IQ1_M:
|
|
2087
|
+
case GGML_TYPE_IQ4_NL:
|
|
2088
|
+
case GGML_TYPE_IQ4_XS:
|
|
2089
|
+
defines.push_back(type_upper + "_GRID");
|
|
2090
|
+
break;
|
|
2091
|
+
case GGML_TYPE_IQ2_XXS:
|
|
2092
|
+
case GGML_TYPE_IQ2_XS:
|
|
2093
|
+
case GGML_TYPE_IQ2_S:
|
|
2094
|
+
case GGML_TYPE_IQ3_XXS:
|
|
2095
|
+
case GGML_TYPE_IQ3_S:
|
|
2096
|
+
defines.push_back(type_upper + "_GRID");
|
|
2097
|
+
defines.push_back(type_upper + "_TABLES");
|
|
2098
|
+
break;
|
|
2099
|
+
case GGML_TYPE_MXFP4:
|
|
2100
|
+
defines.push_back(type_upper + "_LUT");
|
|
2101
|
+
break;
|
|
2102
|
+
default:
|
|
2103
|
+
break;
|
|
2104
|
+
}
|
|
921
2105
|
|
|
922
2106
|
variant += std::string("_") + src0_name;
|
|
923
2107
|
break;
|
|
@@ -927,13 +2111,22 @@ class ggml_webgpu_shader_lib {
|
|
|
927
2111
|
// VEC/SCALAR controls
|
|
928
2112
|
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
929
2113
|
|
|
2114
|
+
const bool is_quant = ggml_is_quantized(context.src0->type);
|
|
2115
|
+
|
|
2116
|
+
uint32_t tile_k;
|
|
2117
|
+
if (key.use_subgroup_matrix) {
|
|
2118
|
+
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
|
|
2119
|
+
} else {
|
|
2120
|
+
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
|
2121
|
+
}
|
|
2122
|
+
|
|
930
2123
|
// Tiles
|
|
931
2124
|
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
|
932
2125
|
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
|
933
|
-
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
|
934
2126
|
|
|
935
2127
|
// Subgroup matrix specifics
|
|
936
2128
|
if (key.use_subgroup_matrix) {
|
|
2129
|
+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
|
937
2130
|
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
|
938
2131
|
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
|
939
2132
|
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
|
@@ -953,12 +2146,13 @@ class ggml_webgpu_shader_lib {
|
|
|
953
2146
|
if (!key.use_subgroup_matrix) {
|
|
954
2147
|
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
|
955
2148
|
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
|
2149
|
+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
|
956
2150
|
}
|
|
957
2151
|
|
|
958
2152
|
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
959
2153
|
|
|
960
2154
|
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
|
961
|
-
decisions->tile_k =
|
|
2155
|
+
decisions->tile_k = tile_k;
|
|
962
2156
|
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
|
963
2157
|
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
|
964
2158
|
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
|
|
@@ -981,84 +2175,276 @@ class ggml_webgpu_shader_lib {
|
|
|
981
2175
|
return mul_mat_fast_pipelines[key];
|
|
982
2176
|
}
|
|
983
2177
|
|
|
984
|
-
webgpu_pipeline
|
|
985
|
-
|
|
986
|
-
|
|
2178
|
+
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2179
|
+
auto it = mul_mat_id_gather_pipelines.find(1);
|
|
2180
|
+
if (it != mul_mat_id_gather_pipelines.end()) {
|
|
2181
|
+
return it->second;
|
|
2182
|
+
}
|
|
2183
|
+
std::vector<std::string> defines;
|
|
2184
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
987
2185
|
|
|
988
|
-
auto
|
|
989
|
-
|
|
2186
|
+
auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines);
|
|
2187
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
2188
|
+
decisions->wg_size = context.max_wg_size;
|
|
2189
|
+
|
|
2190
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather");
|
|
2191
|
+
pipeline.context = decisions;
|
|
2192
|
+
mul_mat_id_gather_pipelines[1] = pipeline;
|
|
2193
|
+
return pipeline;
|
|
2194
|
+
}
|
|
2195
|
+
|
|
2196
|
+
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2197
|
+
ggml_webgpu_mul_mat_id_pipeline_key key = {};
|
|
2198
|
+
key.src0_type = context.src0->type;
|
|
2199
|
+
key.src1_type = context.src1->type;
|
|
2200
|
+
key.n_experts = context.src0->ne[2];
|
|
2201
|
+
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
|
|
2202
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
2203
|
+
1 :
|
|
2204
|
+
0;
|
|
2205
|
+
|
|
2206
|
+
auto it = mul_mat_id_pipelines.find(key);
|
|
2207
|
+
if (it != mul_mat_id_pipelines.end()) {
|
|
990
2208
|
return it->second;
|
|
991
2209
|
}
|
|
992
2210
|
|
|
993
2211
|
std::vector<std::string> defines;
|
|
994
|
-
std::string variant = "
|
|
2212
|
+
std::string variant = "mul_mat_id";
|
|
2213
|
+
defines.push_back("MUL_MAT_ID");
|
|
995
2214
|
|
|
2215
|
+
// src1 type
|
|
996
2216
|
switch (context.src1->type) {
|
|
997
2217
|
case GGML_TYPE_F32:
|
|
998
|
-
defines.push_back("
|
|
999
|
-
variant += "_f32";
|
|
2218
|
+
defines.push_back("SRC1_INNER_TYPE=f32");
|
|
1000
2219
|
break;
|
|
1001
2220
|
case GGML_TYPE_F16:
|
|
1002
|
-
defines.push_back("
|
|
1003
|
-
variant += "_f16";
|
|
2221
|
+
defines.push_back("SRC1_INNER_TYPE=f16");
|
|
1004
2222
|
break;
|
|
1005
2223
|
default:
|
|
1006
|
-
GGML_ABORT("Unsupported src1 type for mul_mat
|
|
2224
|
+
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
|
1007
2225
|
}
|
|
1008
2226
|
|
|
2227
|
+
// src0 type
|
|
1009
2228
|
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
1010
2229
|
const char * src0_name = src0_traits->type_name;
|
|
1011
2230
|
|
|
1012
2231
|
switch (context.src0->type) {
|
|
1013
2232
|
case GGML_TYPE_F32:
|
|
1014
|
-
defines.push_back("
|
|
1015
|
-
defines.push_back("
|
|
2233
|
+
defines.push_back("SRC0_INNER_TYPE=f32");
|
|
2234
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
2235
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
1016
2236
|
variant += "_f32";
|
|
1017
2237
|
break;
|
|
1018
2238
|
case GGML_TYPE_F16:
|
|
1019
|
-
defines.push_back("
|
|
1020
|
-
defines.push_back("
|
|
2239
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
2240
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
2241
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
1021
2242
|
variant += "_f16";
|
|
1022
2243
|
break;
|
|
1023
2244
|
default:
|
|
1024
2245
|
{
|
|
1025
|
-
// quantized types
|
|
1026
2246
|
std::string type_upper = src0_name;
|
|
1027
2247
|
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
1028
2248
|
|
|
1029
|
-
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
|
1030
2249
|
defines.push_back("BYTE_HELPERS");
|
|
1031
|
-
defines.push_back(
|
|
1032
|
-
defines.push_back(
|
|
1033
|
-
defines.push_back(
|
|
1034
|
-
defines.push_back(
|
|
1035
|
-
|
|
2250
|
+
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
|
2251
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
2252
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
2253
|
+
defines.push_back("SRC0_INNER_TYPE=u32");
|
|
2254
|
+
|
|
2255
|
+
switch (context.src0->type) {
|
|
2256
|
+
case GGML_TYPE_IQ1_S:
|
|
2257
|
+
case GGML_TYPE_IQ1_M:
|
|
2258
|
+
case GGML_TYPE_IQ4_NL:
|
|
2259
|
+
case GGML_TYPE_IQ4_XS:
|
|
2260
|
+
defines.push_back(type_upper + "_GRID");
|
|
2261
|
+
break;
|
|
2262
|
+
case GGML_TYPE_IQ2_XXS:
|
|
2263
|
+
case GGML_TYPE_IQ2_XS:
|
|
2264
|
+
case GGML_TYPE_IQ2_S:
|
|
2265
|
+
case GGML_TYPE_IQ3_XXS:
|
|
2266
|
+
case GGML_TYPE_IQ3_S:
|
|
2267
|
+
defines.push_back(type_upper + "_GRID");
|
|
2268
|
+
defines.push_back(type_upper + "_TABLES");
|
|
2269
|
+
break;
|
|
2270
|
+
case GGML_TYPE_MXFP4:
|
|
2271
|
+
defines.push_back(type_upper + "_LUT");
|
|
2272
|
+
break;
|
|
2273
|
+
default:
|
|
2274
|
+
break;
|
|
2275
|
+
}
|
|
1036
2276
|
|
|
1037
2277
|
variant += std::string("_") + src0_name;
|
|
1038
2278
|
break;
|
|
1039
2279
|
}
|
|
1040
2280
|
}
|
|
1041
2281
|
|
|
1042
|
-
|
|
2282
|
+
// VEC/SCALAR controls
|
|
2283
|
+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
1043
2284
|
|
|
1044
|
-
|
|
1045
|
-
|
|
2285
|
+
// mul_mat_id is register-tile only.
|
|
2286
|
+
const uint32_t tile_k =
|
|
2287
|
+
ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
|
2288
|
+
|
|
2289
|
+
// Tiles
|
|
2290
|
+
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
|
2291
|
+
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
|
2292
|
+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
|
2293
|
+
|
|
2294
|
+
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
|
2295
|
+
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
|
2296
|
+
|
|
2297
|
+
// variant suffix for src1 type
|
|
2298
|
+
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
|
2299
|
+
if (key.vectorized) {
|
|
2300
|
+
variant += "_vectorized";
|
|
2301
|
+
}
|
|
2302
|
+
|
|
2303
|
+
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
|
2304
|
+
|
|
2305
|
+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
|
2306
|
+
decisions->tile_k = tile_k;
|
|
2307
|
+
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
|
2308
|
+
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
|
2309
|
+
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
|
2310
|
+
decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
2311
|
+
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
2312
|
+
|
|
2313
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
2314
|
+
pipeline.context = decisions;
|
|
2315
|
+
mul_mat_id_pipelines[key] = pipeline;
|
|
2316
|
+
return mul_mat_id_pipelines[key];
|
|
2317
|
+
}
|
|
2318
|
+
|
|
2319
|
+
webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2320
|
+
ggml_webgpu_mul_mat_id_pipeline_key key = {};
|
|
2321
|
+
key.src0_type = context.src0->type;
|
|
2322
|
+
key.src1_type = context.src1->type;
|
|
2323
|
+
key.n_experts = context.src0->ne[2];
|
|
2324
|
+
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
|
2325
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
2326
|
+
1 :
|
|
2327
|
+
0;
|
|
2328
|
+
|
|
2329
|
+
auto it = mul_mat_id_vec_pipelines.find(key);
|
|
2330
|
+
if (it != mul_mat_id_vec_pipelines.end()) {
|
|
2331
|
+
return it->second;
|
|
2332
|
+
}
|
|
2333
|
+
|
|
2334
|
+
std::vector<std::string> defines;
|
|
2335
|
+
std::string variant = "mul_mat_id_vec";
|
|
2336
|
+
const char * shader_src = wgsl_mul_mat_id_vec;
|
|
2337
|
+
|
|
2338
|
+
// src1 type
|
|
2339
|
+
switch (context.src1->type) {
|
|
2340
|
+
case GGML_TYPE_F32:
|
|
2341
|
+
defines.push_back("SRC1_INNER_TYPE=f32");
|
|
2342
|
+
break;
|
|
2343
|
+
case GGML_TYPE_F16:
|
|
2344
|
+
defines.push_back("SRC1_INNER_TYPE=f16");
|
|
2345
|
+
break;
|
|
2346
|
+
default:
|
|
2347
|
+
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
|
2348
|
+
}
|
|
2349
|
+
|
|
2350
|
+
// src0 type
|
|
2351
|
+
switch (context.src0->type) {
|
|
2352
|
+
case GGML_TYPE_F32:
|
|
2353
|
+
defines.push_back("SRC0_INNER_TYPE=f32");
|
|
2354
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
2355
|
+
variant += "_f32";
|
|
2356
|
+
break;
|
|
2357
|
+
case GGML_TYPE_F16:
|
|
2358
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
2359
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
2360
|
+
variant += "_f16";
|
|
2361
|
+
break;
|
|
2362
|
+
default:
|
|
2363
|
+
{
|
|
2364
|
+
// Quantized types: use helpers but accumulate in f16
|
|
2365
|
+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
2366
|
+
std::string src0_name = src0_traits->type_name;
|
|
2367
|
+
std::string type_upper = src0_name;
|
|
2368
|
+
variant += "_" + src0_name;
|
|
2369
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
2370
|
+
|
|
2371
|
+
defines.push_back("BYTE_HELPERS");
|
|
2372
|
+
defines.push_back("MUL_ACC_" + type_upper);
|
|
2373
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
2374
|
+
defines.push_back("SRC0_INNER_TYPE=u32");
|
|
2375
|
+
switch (context.src0->type) {
|
|
2376
|
+
case GGML_TYPE_IQ1_S:
|
|
2377
|
+
case GGML_TYPE_IQ1_M:
|
|
2378
|
+
case GGML_TYPE_IQ2_S:
|
|
2379
|
+
case GGML_TYPE_IQ3_S:
|
|
2380
|
+
case GGML_TYPE_IQ4_NL:
|
|
2381
|
+
case GGML_TYPE_IQ4_XS:
|
|
2382
|
+
defines.push_back(type_upper + "_GRID");
|
|
2383
|
+
break;
|
|
2384
|
+
case GGML_TYPE_IQ2_XXS:
|
|
2385
|
+
case GGML_TYPE_IQ2_XS:
|
|
2386
|
+
case GGML_TYPE_IQ3_XXS:
|
|
2387
|
+
defines.push_back(type_upper + "_GRID");
|
|
2388
|
+
defines.push_back(type_upper + "_TABLES");
|
|
2389
|
+
break;
|
|
2390
|
+
case GGML_TYPE_MXFP4:
|
|
2391
|
+
defines.push_back(type_upper + "_LUT");
|
|
2392
|
+
break;
|
|
2393
|
+
default:
|
|
2394
|
+
break;
|
|
2395
|
+
}
|
|
2396
|
+
break;
|
|
2397
|
+
}
|
|
2398
|
+
}
|
|
2399
|
+
|
|
2400
|
+
// VEC/SCALAR controls
|
|
2401
|
+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
2402
|
+
|
|
2403
|
+
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
|
2404
|
+
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
|
2405
|
+
|
|
2406
|
+
if (key.src0_type == GGML_TYPE_Q1_0) {
|
|
2407
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
|
2408
|
+
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
|
|
2409
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
|
2410
|
+
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
|
2411
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
|
2412
|
+
}
|
|
2413
|
+
|
|
2414
|
+
// variant suffix for src1 type
|
|
2415
|
+
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
|
2416
|
+
|
|
2417
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
2418
|
+
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
|
2419
|
+
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
|
2420
|
+
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
|
|
2421
|
+
if (key.vectorized) {
|
|
2422
|
+
variant += "_vectorized";
|
|
2423
|
+
}
|
|
2424
|
+
|
|
2425
|
+
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
|
2426
|
+
|
|
2427
|
+
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
2428
|
+
|
|
2429
|
+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
|
2430
|
+
decisions->wg_size = wg_size;
|
|
2431
|
+
decisions->outputs_per_wg = outputs_per_wg;
|
|
1046
2432
|
|
|
1047
2433
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1048
2434
|
pipeline.context = decisions;
|
|
1049
|
-
|
|
1050
|
-
return
|
|
2435
|
+
mul_mat_id_vec_pipelines[key] = pipeline;
|
|
2436
|
+
return mul_mat_id_vec_pipelines[key];
|
|
1051
2437
|
}
|
|
1052
2438
|
|
|
1053
2439
|
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1054
2440
|
const bool is_unary = context.dst->op == GGML_OP_UNARY;
|
|
1055
2441
|
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
|
|
1056
|
-
ggml_webgpu_unary_pipeline_key key = {
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
2442
|
+
ggml_webgpu_unary_pipeline_key key = {};
|
|
2443
|
+
key.type = context.dst->type;
|
|
2444
|
+
key.op = op;
|
|
2445
|
+
key.is_unary = is_unary;
|
|
2446
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
|
|
2447
|
+
key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
|
|
1062
2448
|
|
|
1063
2449
|
auto it = unary_pipelines.find(key);
|
|
1064
2450
|
if (it != unary_pipelines.end()) {
|
|
@@ -1088,25 +2474,88 @@ class ggml_webgpu_shader_lib {
|
|
|
1088
2474
|
variant += "_inplace";
|
|
1089
2475
|
}
|
|
1090
2476
|
|
|
2477
|
+
if (op == GGML_OP_TRI) {
|
|
2478
|
+
switch (key.ttype) {
|
|
2479
|
+
case GGML_TRI_TYPE_LOWER:
|
|
2480
|
+
defines.push_back("TRI_TYPE_LOWER");
|
|
2481
|
+
variant += "_tri_type_lower";
|
|
2482
|
+
break;
|
|
2483
|
+
case GGML_TRI_TYPE_LOWER_DIAG:
|
|
2484
|
+
defines.push_back("TRI_TYPE_LOWER_DIAG");
|
|
2485
|
+
variant += "_tri_type_lower_diag";
|
|
2486
|
+
break;
|
|
2487
|
+
case GGML_TRI_TYPE_UPPER:
|
|
2488
|
+
defines.push_back("TRI_TYPE_UPPER");
|
|
2489
|
+
variant += "_tri_type_upper";
|
|
2490
|
+
break;
|
|
2491
|
+
case GGML_TRI_TYPE_UPPER_DIAG:
|
|
2492
|
+
defines.push_back("TRI_TYPE_UPPER_DIAG");
|
|
2493
|
+
variant += "_tri_upper_diag";
|
|
2494
|
+
break;
|
|
2495
|
+
default:
|
|
2496
|
+
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
|
|
2497
|
+
}
|
|
2498
|
+
}
|
|
2499
|
+
|
|
1091
2500
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1092
2501
|
|
|
1093
2502
|
auto processed = preprocessor.preprocess(wgsl_unary, defines);
|
|
1094
2503
|
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1095
2504
|
decisions->wg_size = context.max_wg_size;
|
|
2505
|
+
decisions->inplace = key.inplace;
|
|
1096
2506
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1097
2507
|
pipeline.context = decisions;
|
|
1098
2508
|
unary_pipelines[key] = pipeline;
|
|
1099
2509
|
return unary_pipelines[key];
|
|
1100
2510
|
}
|
|
1101
2511
|
|
|
2512
|
+
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2513
|
+
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
|
|
2514
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
2515
|
+
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
|
|
2516
|
+
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
|
2517
|
+
|
|
2518
|
+
auto it = rms_norm_mul_pipelines.find(key);
|
|
2519
|
+
if (it != rms_norm_mul_pipelines.end()) {
|
|
2520
|
+
return it->second;
|
|
2521
|
+
}
|
|
2522
|
+
|
|
2523
|
+
std::vector<std::string> defines;
|
|
2524
|
+
std::string op_name = "RMS_NORM_MUL";
|
|
2525
|
+
std::string variant = op_name;
|
|
2526
|
+
|
|
2527
|
+
if (key.inplace) {
|
|
2528
|
+
defines.push_back("INPLACE");
|
|
2529
|
+
variant += "_inplace";
|
|
2530
|
+
} else if (key.overlap) {
|
|
2531
|
+
defines.push_back("OVERLAP");
|
|
2532
|
+
variant += "_overlap";
|
|
2533
|
+
} else if (key.src_overlap) {
|
|
2534
|
+
defines.push_back("SRC_OVERLAP");
|
|
2535
|
+
variant += "_src_overlap";
|
|
2536
|
+
}
|
|
2537
|
+
|
|
2538
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
2539
|
+
|
|
2540
|
+
auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines);
|
|
2541
|
+
auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
|
|
2542
|
+
pipeline_decisions->wg_size = context.max_wg_size;
|
|
2543
|
+
pipeline_decisions->inplace = key.inplace;
|
|
2544
|
+
pipeline_decisions->overlap = key.overlap;
|
|
2545
|
+
pipeline_decisions->src_overlap = key.src_overlap;
|
|
2546
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
2547
|
+
pipeline.context = pipeline_decisions;
|
|
2548
|
+
rms_norm_mul_pipelines[key] = pipeline;
|
|
2549
|
+
return rms_norm_mul_pipelines[key];
|
|
2550
|
+
}
|
|
2551
|
+
|
|
1102
2552
|
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1103
|
-
ggml_webgpu_binary_pipeline_key key = {
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
};
|
|
2553
|
+
ggml_webgpu_binary_pipeline_key key = {};
|
|
2554
|
+
key.type = context.dst->type;
|
|
2555
|
+
key.op = context.dst->op;
|
|
2556
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
2557
|
+
key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst);
|
|
2558
|
+
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
|
1110
2559
|
|
|
1111
2560
|
auto it = binary_pipelines.find(key);
|
|
1112
2561
|
if (it != binary_pipelines.end()) {
|
|
@@ -1145,19 +2594,54 @@ class ggml_webgpu_shader_lib {
|
|
|
1145
2594
|
|
|
1146
2595
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1147
2596
|
|
|
1148
|
-
auto processed
|
|
1149
|
-
auto
|
|
1150
|
-
|
|
2597
|
+
auto processed = preprocessor.preprocess(wgsl_binary, defines);
|
|
2598
|
+
auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
|
|
2599
|
+
pipeline_decisions->wg_size = context.max_wg_size;
|
|
2600
|
+
pipeline_decisions->inplace = key.inplace;
|
|
2601
|
+
pipeline_decisions->overlap = key.overlap;
|
|
2602
|
+
pipeline_decisions->src_overlap = key.src_overlap;
|
|
2603
|
+
|
|
1151
2604
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1152
|
-
pipeline.context =
|
|
2605
|
+
pipeline.context = pipeline_decisions;
|
|
1153
2606
|
binary_pipelines[key] = pipeline;
|
|
1154
2607
|
return binary_pipelines[key];
|
|
1155
2608
|
}
|
|
1156
2609
|
|
|
2610
|
+
webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2611
|
+
ggml_webgpu_add_id_pipeline_key key = {};
|
|
2612
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
2613
|
+
|
|
2614
|
+
auto it = add_id_pipelines.find(key);
|
|
2615
|
+
if (it != add_id_pipelines.end()) {
|
|
2616
|
+
return it->second;
|
|
2617
|
+
}
|
|
2618
|
+
|
|
2619
|
+
std::vector<std::string> defines;
|
|
2620
|
+
std::string variant = "add_id";
|
|
2621
|
+
const char * shader_src = wgsl_add_id;
|
|
2622
|
+
|
|
2623
|
+
if (key.inplace) {
|
|
2624
|
+
defines.push_back("INPLACE");
|
|
2625
|
+
variant += "_inplace";
|
|
2626
|
+
}
|
|
2627
|
+
|
|
2628
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
2629
|
+
|
|
2630
|
+
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
2631
|
+
auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
2632
|
+
pipeline_decisions->wg_size = context.max_wg_size;
|
|
2633
|
+
pipeline_decisions->inplace = key.inplace;
|
|
2634
|
+
|
|
2635
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
2636
|
+
pipeline.context = pipeline_decisions;
|
|
2637
|
+
add_id_pipelines[key] = pipeline;
|
|
2638
|
+
return pipeline;
|
|
2639
|
+
}
|
|
2640
|
+
|
|
1157
2641
|
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1158
|
-
ggml_webgpu_concat_pipeline_key key = {
|
|
1159
|
-
|
|
1160
|
-
|
|
2642
|
+
ggml_webgpu_concat_pipeline_key key = {};
|
|
2643
|
+
key.type = context.dst->type;
|
|
2644
|
+
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
|
1161
2645
|
|
|
1162
2646
|
auto it = concat_pipelines.find(key);
|
|
1163
2647
|
if (it != concat_pipelines.end()) {
|
|
@@ -1180,11 +2664,17 @@ class ggml_webgpu_shader_lib {
|
|
|
1180
2664
|
GGML_ABORT("Unsupported type for concat shader");
|
|
1181
2665
|
}
|
|
1182
2666
|
|
|
2667
|
+
if (key.src_overlap) {
|
|
2668
|
+
defines.push_back("SRC_OVERLAP");
|
|
2669
|
+
variant += "_src_overlap";
|
|
2670
|
+
}
|
|
2671
|
+
|
|
1183
2672
|
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1184
2673
|
|
|
1185
2674
|
auto processed = preprocessor.preprocess(wgsl_concat, defines);
|
|
1186
|
-
auto decisions = std::make_shared<
|
|
2675
|
+
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
|
|
1187
2676
|
decisions->wg_size = context.max_wg_size;
|
|
2677
|
+
decisions->src_overlap = key.src_overlap;
|
|
1188
2678
|
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1189
2679
|
pipeline.context = decisions;
|
|
1190
2680
|
concat_pipelines[key] = pipeline;
|
|
@@ -1192,9 +2682,8 @@ class ggml_webgpu_shader_lib {
|
|
|
1192
2682
|
}
|
|
1193
2683
|
|
|
1194
2684
|
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1195
|
-
ggml_webgpu_repeat_pipeline_key key = {
|
|
1196
|
-
|
|
1197
|
-
};
|
|
2685
|
+
ggml_webgpu_repeat_pipeline_key key = {};
|
|
2686
|
+
key.type = context.dst->type;
|
|
1198
2687
|
|
|
1199
2688
|
auto it = repeat_pipelines.find(key);
|
|
1200
2689
|
if (it != repeat_pipelines.end()) {
|
|
@@ -1233,102 +2722,551 @@ class ggml_webgpu_shader_lib {
|
|
|
1233
2722
|
}
|
|
1234
2723
|
|
|
1235
2724
|
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1236
|
-
const bool
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
ggml_webgpu_flash_attn_pipeline_key key = {
|
|
1243
|
-
|
|
1244
|
-
.
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
.
|
|
1250
|
-
|
|
2725
|
+
const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
|
|
2726
|
+
context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);
|
|
2727
|
+
ggml_webgpu_flash_attn_decisions decisions = {};
|
|
2728
|
+
decisions.use_sg_matrix = can_use_subgroup_matrix;
|
|
2729
|
+
decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
|
2730
|
+
|
|
2731
|
+
ggml_webgpu_flash_attn_pipeline_key key = {};
|
|
2732
|
+
key.common =
|
|
2733
|
+
ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u);
|
|
2734
|
+
key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct;
|
|
2735
|
+
key.use_sg_matrix = decisions.use_sg_matrix;
|
|
2736
|
+
|
|
2737
|
+
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
|
|
2738
|
+
context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u,
|
|
2739
|
+
key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
|
|
2740
|
+
GGML_ASSERT(max_kv_tile > 0);
|
|
2741
|
+
|
|
2742
|
+
decisions.kv_tile = decisions.use_sg_matrix ?
|
|
2743
|
+
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) :
|
|
2744
|
+
std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
|
|
2745
|
+
decisions.wg_size =
|
|
2746
|
+
decisions.use_sg_matrix ?
|
|
2747
|
+
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) :
|
|
2748
|
+
std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
|
|
2749
|
+
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
|
|
2750
|
+
|
|
2751
|
+
if (key.common.kv_direct) {
|
|
2752
|
+
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
|
2753
|
+
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
|
2754
|
+
decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size;
|
|
2755
|
+
}
|
|
2756
|
+
}
|
|
1251
2757
|
|
|
1252
2758
|
auto it = flash_attn_pipelines.find(key);
|
|
1253
2759
|
if (it != flash_attn_pipelines.end()) {
|
|
1254
2760
|
return it->second;
|
|
1255
2761
|
}
|
|
1256
2762
|
|
|
2763
|
+
std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile";
|
|
2764
|
+
std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
|
|
2765
|
+
decisions.kv_tile, decisions.wg_size);
|
|
2766
|
+
const char * shader_src = nullptr;
|
|
2767
|
+
if (!key.use_sg_matrix) {
|
|
2768
|
+
shader_src = wgsl_flash_attn_tile;
|
|
2769
|
+
defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
|
|
2770
|
+
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
|
2771
|
+
variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
|
|
2772
|
+
std::to_string(context.max_subgroup_size);
|
|
2773
|
+
} else {
|
|
2774
|
+
shader_src = wgsl_flash_attn;
|
|
2775
|
+
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
|
2776
|
+
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
|
2777
|
+
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
|
2778
|
+
}
|
|
2779
|
+
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
|
2780
|
+
webgpu_pipeline pipeline =
|
|
2781
|
+
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
|
|
2782
|
+
pipeline.context = pipeline_decisions;
|
|
2783
|
+
flash_attn_pipelines[key] = pipeline;
|
|
2784
|
+
return flash_attn_pipelines[key];
|
|
2785
|
+
}
|
|
2786
|
+
|
|
2787
|
+
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2788
|
+
ggml_webgpu_flash_attn_vec_pipeline_key key = {};
|
|
2789
|
+
key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
|
|
2790
|
+
|
|
2791
|
+
auto it = flash_attn_vec_pipelines.find(key);
|
|
2792
|
+
if (it != flash_attn_vec_pipelines.end()) {
|
|
2793
|
+
return it->second;
|
|
2794
|
+
}
|
|
2795
|
+
|
|
2796
|
+
ggml_webgpu_flash_attn_vec_decisions decisions = {};
|
|
2797
|
+
decisions.kv_tile =
|
|
2798
|
+
ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
|
|
2799
|
+
key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
|
|
2800
|
+
decisions.wg_size = context.max_subgroup_size;
|
|
2801
|
+
|
|
2802
|
+
std::string variant = "flash_attn_vec";
|
|
2803
|
+
std::vector<std::string> defines =
|
|
2804
|
+
ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);
|
|
2805
|
+
if (key.common.has_mask) {
|
|
2806
|
+
defines.push_back("BLK");
|
|
2807
|
+
variant.resize(variant.size() - (sizeof("_mask") - 1));
|
|
2808
|
+
variant += "_mask_blk";
|
|
2809
|
+
}
|
|
2810
|
+
uint32_t vec_ne = 1u;
|
|
2811
|
+
if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
|
|
2812
|
+
key.common.head_dim_qk == key.common.head_dim_v) {
|
|
2813
|
+
switch (key.common.head_dim_qk) {
|
|
2814
|
+
case 64:
|
|
2815
|
+
case 192:
|
|
2816
|
+
case 576:
|
|
2817
|
+
vec_ne = 2u;
|
|
2818
|
+
break;
|
|
2819
|
+
case 96:
|
|
2820
|
+
vec_ne = 4u;
|
|
2821
|
+
break;
|
|
2822
|
+
default:
|
|
2823
|
+
break;
|
|
2824
|
+
}
|
|
2825
|
+
}
|
|
2826
|
+
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
|
|
2827
|
+
|
|
2828
|
+
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
|
|
2829
|
+
webgpu_pipeline pipeline =
|
|
2830
|
+
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
|
|
2831
|
+
pipeline.context = pipeline_decisions;
|
|
2832
|
+
flash_attn_vec_pipelines[key] = pipeline;
|
|
2833
|
+
return flash_attn_vec_pipelines[key];
|
|
2834
|
+
}
|
|
2835
|
+
|
|
2836
|
+
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
|
|
2837
|
+
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
|
|
2838
|
+
key.kv_tile = kv_tile;
|
|
2839
|
+
auto it = flash_attn_blk_pipelines.find(key);
|
|
2840
|
+
if (it != flash_attn_blk_pipelines.end()) {
|
|
2841
|
+
return it->second;
|
|
2842
|
+
}
|
|
2843
|
+
|
|
1257
2844
|
std::vector<std::string> defines;
|
|
1258
|
-
std::string variant = "
|
|
2845
|
+
std::string variant = "flash_attn_vec_blk";
|
|
1259
2846
|
|
|
1260
|
-
|
|
2847
|
+
defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile));
|
|
2848
|
+
variant += std::string("_kvt") + std::to_string(key.kv_tile);
|
|
2849
|
+
|
|
2850
|
+
uint32_t wg_size = 1;
|
|
2851
|
+
while ((wg_size << 1) <= context.max_wg_size) {
|
|
2852
|
+
wg_size <<= 1;
|
|
2853
|
+
}
|
|
2854
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
2855
|
+
variant += std::string("_wg") + std::to_string(wg_size);
|
|
2856
|
+
|
|
2857
|
+
webgpu_pipeline pipeline =
|
|
2858
|
+
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
|
|
2859
|
+
flash_attn_blk_pipelines[key] = pipeline;
|
|
2860
|
+
return flash_attn_blk_pipelines[key];
|
|
2861
|
+
}
|
|
2862
|
+
|
|
2863
|
+
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2864
|
+
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
|
|
2865
|
+
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
|
2866
|
+
key.dst_type = context.dst->type;
|
|
2867
|
+
key.wg_size = context.max_wg_size;
|
|
2868
|
+
auto it = flash_attn_vec_reduce_pipelines.find(key);
|
|
2869
|
+
if (it != flash_attn_vec_reduce_pipelines.end()) {
|
|
2870
|
+
return it->second;
|
|
2871
|
+
}
|
|
2872
|
+
|
|
2873
|
+
std::vector<std::string> defines;
|
|
2874
|
+
std::string variant = "flash_attn_vec_reduce";
|
|
2875
|
+
|
|
2876
|
+
switch (key.dst_type) {
|
|
1261
2877
|
case GGML_TYPE_F32:
|
|
1262
|
-
defines.push_back("
|
|
2878
|
+
defines.push_back("DST_F32");
|
|
1263
2879
|
break;
|
|
1264
2880
|
case GGML_TYPE_F16:
|
|
1265
|
-
defines.push_back("
|
|
2881
|
+
defines.push_back("DST_F16");
|
|
1266
2882
|
break;
|
|
1267
|
-
|
|
1268
|
-
|
|
2883
|
+
default:
|
|
2884
|
+
GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
|
|
2885
|
+
}
|
|
2886
|
+
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
|
2887
|
+
|
|
2888
|
+
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
|
2889
|
+
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
|
2890
|
+
|
|
2891
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
2892
|
+
variant += std::string("_wg") + std::to_string(context.max_wg_size);
|
|
2893
|
+
|
|
2894
|
+
webgpu_pipeline pipeline =
|
|
2895
|
+
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
|
|
2896
|
+
flash_attn_vec_reduce_pipelines[key] = pipeline;
|
|
2897
|
+
return flash_attn_vec_reduce_pipelines[key];
|
|
2898
|
+
}
|
|
2899
|
+
|
|
2900
|
+
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2901
|
+
ggml_webgpu_cpy_pipeline_key key = {};
|
|
2902
|
+
key.src_type = context.src0->type;
|
|
2903
|
+
key.dst_type = context.dst->type;
|
|
2904
|
+
|
|
2905
|
+
auto it = cpy_pipelines.find(key);
|
|
2906
|
+
if (it != cpy_pipelines.end()) {
|
|
2907
|
+
return it->second;
|
|
2908
|
+
}
|
|
2909
|
+
|
|
2910
|
+
std::vector<std::string> defines;
|
|
2911
|
+
std::string variant = "cpy";
|
|
2912
|
+
|
|
2913
|
+
switch (key.src_type) {
|
|
2914
|
+
case GGML_TYPE_F32:
|
|
2915
|
+
defines.push_back("SRC_F32");
|
|
2916
|
+
variant += "_f32";
|
|
1269
2917
|
break;
|
|
1270
|
-
case
|
|
1271
|
-
defines.push_back("
|
|
2918
|
+
case GGML_TYPE_F16:
|
|
2919
|
+
defines.push_back("SRC_F16");
|
|
2920
|
+
variant += "_f16";
|
|
2921
|
+
break;
|
|
2922
|
+
default:
|
|
2923
|
+
GGML_ABORT("Unsupported src type for cpy shader");
|
|
2924
|
+
}
|
|
2925
|
+
|
|
2926
|
+
switch (key.dst_type) {
|
|
2927
|
+
case GGML_TYPE_F32:
|
|
2928
|
+
defines.push_back("DST_F32");
|
|
2929
|
+
variant += "_f32";
|
|
2930
|
+
break;
|
|
2931
|
+
case GGML_TYPE_F16:
|
|
2932
|
+
defines.push_back("DST_F16");
|
|
2933
|
+
variant += "_f16";
|
|
2934
|
+
break;
|
|
2935
|
+
case GGML_TYPE_I32:
|
|
2936
|
+
defines.push_back("DST_I32");
|
|
2937
|
+
variant += "_i32";
|
|
2938
|
+
break;
|
|
2939
|
+
default:
|
|
2940
|
+
GGML_ABORT("Unsupported dst type for cpy shader");
|
|
2941
|
+
}
|
|
2942
|
+
|
|
2943
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
2944
|
+
|
|
2945
|
+
auto processed = preprocessor.preprocess(wgsl_cpy, defines);
|
|
2946
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
2947
|
+
decisions->wg_size = context.max_wg_size;
|
|
2948
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
2949
|
+
pipeline.context = decisions;
|
|
2950
|
+
cpy_pipelines[key] = pipeline;
|
|
2951
|
+
return cpy_pipelines[key];
|
|
2952
|
+
}
|
|
2953
|
+
|
|
2954
|
+
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2955
|
+
ggml_webgpu_glu_pipeline_key key = {};
|
|
2956
|
+
key.glu_op = ggml_get_glu_op(context.dst);
|
|
2957
|
+
key.type = context.dst->type;
|
|
2958
|
+
key.split = (context.src1 != nullptr);
|
|
2959
|
+
|
|
2960
|
+
auto it = glu_pipelines.find(key);
|
|
2961
|
+
if (it != glu_pipelines.end()) {
|
|
2962
|
+
return it->second;
|
|
2963
|
+
}
|
|
2964
|
+
|
|
2965
|
+
std::vector<std::string> defines;
|
|
2966
|
+
std::string variant = "glu";
|
|
2967
|
+
|
|
2968
|
+
switch (key.glu_op) {
|
|
2969
|
+
case GGML_GLU_OP_REGLU:
|
|
2970
|
+
defines.push_back("OP_REGLU");
|
|
2971
|
+
variant += "_reglu";
|
|
2972
|
+
break;
|
|
2973
|
+
case GGML_GLU_OP_GEGLU:
|
|
2974
|
+
defines.push_back("OP_GEGLU");
|
|
2975
|
+
variant += "_geglu";
|
|
2976
|
+
break;
|
|
2977
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2978
|
+
defines.push_back("OP_SWIGLU");
|
|
2979
|
+
variant += "_swiglu";
|
|
2980
|
+
break;
|
|
2981
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
2982
|
+
defines.push_back("OP_SWIGLU_OAI");
|
|
2983
|
+
variant += "_swiglu_oai";
|
|
2984
|
+
break;
|
|
2985
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2986
|
+
defines.push_back("OP_GEGLU_ERF");
|
|
2987
|
+
variant += "_geglu_erf";
|
|
2988
|
+
break;
|
|
2989
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2990
|
+
defines.push_back("OP_GEGLU_QUICK");
|
|
2991
|
+
variant += "_geglu_quick";
|
|
2992
|
+
break;
|
|
2993
|
+
default:
|
|
2994
|
+
GGML_ABORT("Unsupported GLU op");
|
|
2995
|
+
}
|
|
2996
|
+
switch (key.type) {
|
|
2997
|
+
case GGML_TYPE_F32:
|
|
2998
|
+
defines.push_back("TYPE_F32");
|
|
2999
|
+
variant += "_f32";
|
|
3000
|
+
break;
|
|
3001
|
+
case GGML_TYPE_F16:
|
|
3002
|
+
defines.push_back("TYPE_F16");
|
|
3003
|
+
variant += "_f16";
|
|
3004
|
+
break;
|
|
3005
|
+
default:
|
|
3006
|
+
GGML_ABORT("Unsupported type for GLU shader");
|
|
3007
|
+
}
|
|
3008
|
+
|
|
3009
|
+
if (key.split) {
|
|
3010
|
+
variant += "_split";
|
|
3011
|
+
} else {
|
|
3012
|
+
defines.push_back("NO_SPLIT");
|
|
3013
|
+
}
|
|
3014
|
+
|
|
3015
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
3016
|
+
|
|
3017
|
+
auto processed = preprocessor.preprocess(wgsl_glu, defines);
|
|
3018
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3019
|
+
decisions->wg_size = context.max_wg_size;
|
|
3020
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3021
|
+
pipeline.context = decisions;
|
|
3022
|
+
glu_pipelines[key] = pipeline;
|
|
3023
|
+
return glu_pipelines[key];
|
|
3024
|
+
}
|
|
3025
|
+
|
|
3026
|
+
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
3027
|
+
ggml_webgpu_rope_pipeline_key key = {};
|
|
3028
|
+
key.type = context.dst->type;
|
|
3029
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
3030
|
+
key.has_ff = (context.src2 != nullptr);
|
|
3031
|
+
|
|
3032
|
+
auto it = rope_pipelines.find(key);
|
|
3033
|
+
if (it != rope_pipelines.end()) {
|
|
3034
|
+
return it->second;
|
|
3035
|
+
}
|
|
3036
|
+
|
|
3037
|
+
std::vector<std::string> defines;
|
|
3038
|
+
std::string variant = "rope";
|
|
3039
|
+
|
|
3040
|
+
switch (key.type) {
|
|
3041
|
+
case GGML_TYPE_F32:
|
|
3042
|
+
defines.push_back("TYPE_F32");
|
|
3043
|
+
variant += "_f32";
|
|
3044
|
+
break;
|
|
3045
|
+
case GGML_TYPE_F16:
|
|
3046
|
+
defines.push_back("TYPE_F16");
|
|
3047
|
+
variant += "_f16";
|
|
1272
3048
|
break;
|
|
1273
3049
|
default:
|
|
1274
|
-
GGML_ABORT("Unsupported
|
|
3050
|
+
GGML_ABORT("Unsupported type for ROPE shader");
|
|
3051
|
+
}
|
|
3052
|
+
|
|
3053
|
+
if (key.inplace) {
|
|
3054
|
+
defines.push_back("INPLACE");
|
|
3055
|
+
variant += "_inplace";
|
|
3056
|
+
}
|
|
3057
|
+
|
|
3058
|
+
if (key.has_ff) {
|
|
3059
|
+
defines.push_back("FF_FUNC");
|
|
3060
|
+
variant += "_ff";
|
|
1275
3061
|
}
|
|
1276
|
-
|
|
3062
|
+
|
|
3063
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
3064
|
+
|
|
3065
|
+
auto processed = preprocessor.preprocess(wgsl_rope, defines);
|
|
3066
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3067
|
+
decisions->wg_size = context.max_wg_size;
|
|
3068
|
+
decisions->inplace = key.inplace;
|
|
3069
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3070
|
+
pipeline.context = decisions;
|
|
3071
|
+
rope_pipelines[key] = pipeline;
|
|
3072
|
+
return rope_pipelines[key];
|
|
3073
|
+
}
|
|
3074
|
+
|
|
3075
|
+
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
3076
|
+
ggml_webgpu_soft_max_pipeline_key key = {};
|
|
3077
|
+
key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
|
|
3078
|
+
key.has_mask = (context.src1 != nullptr);
|
|
3079
|
+
key.has_sink = (context.src2 != nullptr);
|
|
3080
|
+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
|
|
3081
|
+
|
|
3082
|
+
auto it = soft_max_pipelines.find(key);
|
|
3083
|
+
if (it != soft_max_pipelines.end()) {
|
|
3084
|
+
return it->second;
|
|
3085
|
+
}
|
|
3086
|
+
|
|
3087
|
+
std::vector<std::string> defines;
|
|
3088
|
+
std::string variant = "soft_max";
|
|
1277
3089
|
|
|
1278
3090
|
if (key.has_mask) {
|
|
1279
|
-
defines.push_back("
|
|
1280
|
-
|
|
3091
|
+
defines.push_back("HAS_MASK");
|
|
3092
|
+
switch (key.mask_type) {
|
|
3093
|
+
case GGML_TYPE_F32:
|
|
3094
|
+
defines.push_back("MASK_F32");
|
|
3095
|
+
variant += "_mask_f32";
|
|
3096
|
+
break;
|
|
3097
|
+
case GGML_TYPE_F16:
|
|
3098
|
+
defines.push_back("MASK_F16");
|
|
3099
|
+
variant += "_mask_f16";
|
|
3100
|
+
break;
|
|
3101
|
+
default:
|
|
3102
|
+
GGML_ABORT("Unsupported type for SOFT_MAX shader");
|
|
3103
|
+
}
|
|
1281
3104
|
}
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
3105
|
+
|
|
3106
|
+
if (key.has_sink) {
|
|
3107
|
+
defines.push_back("HAS_SINK");
|
|
3108
|
+
variant += "_sink";
|
|
1285
3109
|
}
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
3110
|
+
|
|
3111
|
+
if (key.inplace) {
|
|
3112
|
+
defines.push_back("INPLACE");
|
|
3113
|
+
variant += "_inplace";
|
|
1289
3114
|
}
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
3115
|
+
|
|
3116
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
3117
|
+
|
|
3118
|
+
auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
|
|
3119
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3120
|
+
decisions->wg_size = context.max_wg_size;
|
|
3121
|
+
decisions->inplace = key.inplace;
|
|
3122
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3123
|
+
pipeline.context = decisions;
|
|
3124
|
+
soft_max_pipelines[key] = pipeline;
|
|
3125
|
+
return soft_max_pipelines[key];
|
|
3126
|
+
}
|
|
3127
|
+
|
|
3128
|
+
webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
3129
|
+
ggml_webgpu_conv2d_pipeline_key key = {};
|
|
3130
|
+
key.weight_type = context.src0->type;
|
|
3131
|
+
key.input_type = context.src1->type;
|
|
3132
|
+
key.output_type = context.dst->type;
|
|
3133
|
+
|
|
3134
|
+
auto it = conv2d_pipelines.find(key);
|
|
3135
|
+
if (it != conv2d_pipelines.end()) {
|
|
3136
|
+
return it->second;
|
|
1293
3137
|
}
|
|
1294
3138
|
|
|
1295
|
-
|
|
1296
|
-
variant
|
|
3139
|
+
std::vector<std::string> defines;
|
|
3140
|
+
std::string variant = "conv_2d";
|
|
3141
|
+
|
|
3142
|
+
auto push_type_defines = [&](const char * prefix, ggml_type type) {
|
|
3143
|
+
std::string s_prefix = prefix;
|
|
3144
|
+
if (type == GGML_TYPE_F32) {
|
|
3145
|
+
defines.push_back(s_prefix + "_F32");
|
|
3146
|
+
} else if (type == GGML_TYPE_F16) {
|
|
3147
|
+
defines.push_back(s_prefix + "_F16");
|
|
3148
|
+
} else {
|
|
3149
|
+
GGML_ABORT("Unsupported type for CONV_2D shader");
|
|
3150
|
+
}
|
|
3151
|
+
};
|
|
1297
3152
|
|
|
1298
|
-
|
|
1299
|
-
|
|
3153
|
+
push_type_defines("WEIGHT", key.weight_type);
|
|
3154
|
+
push_type_defines("INPUT", key.input_type);
|
|
3155
|
+
push_type_defines("OUTPUT", key.output_type);
|
|
3156
|
+
|
|
3157
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1300
3158
|
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
3159
|
+
auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
|
|
3160
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3161
|
+
decisions->wg_size = context.max_wg_size;
|
|
3162
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3163
|
+
pipeline.context = decisions;
|
|
3164
|
+
conv2d_pipelines[key] = pipeline;
|
|
3165
|
+
return conv2d_pipelines[key];
|
|
3166
|
+
}
|
|
3167
|
+
|
|
3168
|
+
webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
3169
|
+
ggml_webgpu_im2col_pipeline_key key = {};
|
|
3170
|
+
key.input_type = context.src1->type;
|
|
3171
|
+
key.output_type = context.dst->type;
|
|
3172
|
+
|
|
3173
|
+
auto it = im2col_pipelines.find(key);
|
|
3174
|
+
if (it != im2col_pipelines.end()) {
|
|
3175
|
+
return it->second;
|
|
3176
|
+
}
|
|
3177
|
+
|
|
3178
|
+
std::vector<std::string> defines;
|
|
3179
|
+
std::string variant = "im2col";
|
|
3180
|
+
|
|
3181
|
+
auto push_type_defines = [&](const char * prefix, ggml_type type) {
|
|
3182
|
+
std::string s_prefix = prefix;
|
|
3183
|
+
if (type == GGML_TYPE_F32) {
|
|
3184
|
+
defines.push_back(s_prefix + "_F32");
|
|
3185
|
+
} else if (type == GGML_TYPE_F16) {
|
|
3186
|
+
defines.push_back(s_prefix + "_F16");
|
|
3187
|
+
} else {
|
|
3188
|
+
GGML_ABORT("Unsupported type for IM2COL shader");
|
|
1313
3189
|
}
|
|
3190
|
+
};
|
|
3191
|
+
|
|
3192
|
+
push_type_defines("INPUT", key.input_type);
|
|
3193
|
+
push_type_defines("OUTPUT", key.output_type);
|
|
3194
|
+
|
|
3195
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
3196
|
+
|
|
3197
|
+
auto processed = preprocessor.preprocess(wgsl_im2col, defines);
|
|
3198
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3199
|
+
decisions->wg_size = context.max_wg_size;
|
|
3200
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3201
|
+
pipeline.context = decisions;
|
|
3202
|
+
im2col_pipelines[key] = pipeline;
|
|
3203
|
+
return im2col_pipelines[key];
|
|
3204
|
+
}
|
|
3205
|
+
|
|
3206
|
+
webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
3207
|
+
const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
|
|
3208
|
+
const uint32_t base_mode = mode_flags & 0xFFu;
|
|
3209
|
+
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;
|
|
3210
|
+
|
|
3211
|
+
ggml_webgpu_upscale_pipeline_key key = {};
|
|
3212
|
+
key.input_type = context.src0->type;
|
|
3213
|
+
key.output_type = context.dst->type;
|
|
3214
|
+
key.base_mode = base_mode;
|
|
3215
|
+
key.antialias = antialias;
|
|
3216
|
+
|
|
3217
|
+
auto it = upscale_pipelines.find(key);
|
|
3218
|
+
if (it != upscale_pipelines.end()) {
|
|
3219
|
+
return it->second;
|
|
1314
3220
|
}
|
|
1315
3221
|
|
|
1316
|
-
|
|
1317
|
-
|
|
3222
|
+
std::vector<std::string> defines;
|
|
3223
|
+
std::string variant = "upscale";
|
|
1318
3224
|
|
|
1319
|
-
|
|
1320
|
-
|
|
3225
|
+
if (key.input_type == GGML_TYPE_F16) {
|
|
3226
|
+
defines.push_back("SRC_F16");
|
|
3227
|
+
variant += "_src_f16";
|
|
3228
|
+
} else {
|
|
3229
|
+
variant += "_src_f32";
|
|
3230
|
+
}
|
|
1321
3231
|
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
3232
|
+
if (key.output_type == GGML_TYPE_F16) {
|
|
3233
|
+
defines.push_back("DST_F16");
|
|
3234
|
+
variant += "_dst_f16";
|
|
3235
|
+
} else {
|
|
3236
|
+
variant += "_dst_f32";
|
|
3237
|
+
}
|
|
1327
3238
|
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
3239
|
+
switch (base_mode) {
|
|
3240
|
+
case GGML_SCALE_MODE_NEAREST:
|
|
3241
|
+
defines.push_back("NEAREST");
|
|
3242
|
+
variant += "_nearest";
|
|
3243
|
+
break;
|
|
3244
|
+
case GGML_SCALE_MODE_BILINEAR:
|
|
3245
|
+
defines.push_back("BILINEAR");
|
|
3246
|
+
variant += "_bilinear";
|
|
3247
|
+
break;
|
|
3248
|
+
case GGML_SCALE_MODE_BICUBIC:
|
|
3249
|
+
defines.push_back("BICUBIC");
|
|
3250
|
+
variant += "_bicubic";
|
|
3251
|
+
break;
|
|
3252
|
+
default:
|
|
3253
|
+
GGML_ABORT("Unsupported upscale mode");
|
|
3254
|
+
}
|
|
3255
|
+
|
|
3256
|
+
if (antialias) {
|
|
3257
|
+
defines.push_back("ANTIALIAS");
|
|
3258
|
+
variant += "_aa";
|
|
3259
|
+
}
|
|
3260
|
+
|
|
3261
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
3262
|
+
|
|
3263
|
+
auto processed = preprocessor.preprocess(wgsl_upscale, defines);
|
|
3264
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
3265
|
+
decisions->wg_size = context.max_wg_size;
|
|
3266
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
3267
|
+
pipeline.context = decisions;
|
|
3268
|
+
upscale_pipelines[key] = pipeline;
|
|
3269
|
+
return upscale_pipelines[key];
|
|
1332
3270
|
}
|
|
1333
3271
|
|
|
1334
3272
|
private:
|
|
@@ -1350,25 +3288,6 @@ class ggml_webgpu_shader_lib {
|
|
|
1350
3288
|
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
|
1351
3289
|
return { device.CreateComputePipeline(&pipeline_desc), label };
|
|
1352
3290
|
}
|
|
1353
|
-
|
|
1354
|
-
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
|
|
1355
|
-
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
|
1356
|
-
const size_t q_tile = context.sg_mat_m;
|
|
1357
|
-
const size_t base_q_bytes =
|
|
1358
|
-
(context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
|
1359
|
-
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
1360
|
-
size_t bytes_per_kv = 0;
|
|
1361
|
-
if (!context.key.kv_direct) {
|
|
1362
|
-
bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
|
|
1363
|
-
}
|
|
1364
|
-
if (context.key.has_mask) {
|
|
1365
|
-
bytes_per_kv += q_tile;
|
|
1366
|
-
}
|
|
1367
|
-
bytes_per_kv += q_tile;
|
|
1368
|
-
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
|
1369
|
-
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
|
1370
|
-
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
|
1371
|
-
}
|
|
1372
3291
|
};
|
|
1373
3292
|
|
|
1374
3293
|
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|