whispercpp 1.3.5 → 1.3.7
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.document +3 -0
- data/.rdoc_options +2 -0
- data/LICENSE +1 -1
- data/README.md +133 -3
- data/Rakefile +18 -3
- data/ext/dependencies.rb +10 -4
- data/ext/dependencies_for_windows.rb +17 -0
- data/ext/extconf.rb +20 -7
- data/ext/options.rb +54 -14
- data/ext/options_for_windows.rb +51 -0
- data/ext/ruby_whisper.c +56 -46
- data/ext/ruby_whisper.h +165 -2
- data/ext/ruby_whisper_context.c +297 -126
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_log_queue.c +180 -0
- data/ext/ruby_whisper_log_settable.h +47 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_parakeet.c +49 -0
- data/ext/ruby_whisper_parakeet_context.c +304 -0
- data/ext/ruby_whisper_parakeet_context_params.c +117 -0
- data/ext/ruby_whisper_parakeet_model.c +84 -0
- data/ext/ruby_whisper_parakeet_params.c +548 -0
- data/ext/ruby_whisper_parakeet_segment.c +157 -0
- data/ext/ruby_whisper_parakeet_token.c +188 -0
- data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
- data/ext/ruby_whisper_params.c +256 -66
- data/ext/ruby_whisper_segment.c +6 -7
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +46 -16
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +41 -3
- data/ext/sources/CMakePresets.json +95 -0
- data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
- data/ext/sources/cmake/parakeet.pc.in +10 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/whisper.pc.in +1 -1
- data/ext/sources/examples/CMakeLists.txt +4 -2
- data/ext/sources/examples/bench/bench.cpp +24 -19
- data/ext/sources/examples/cli/cli.cpp +51 -9
- data/ext/sources/examples/common-ggml.cpp +4 -0
- data/ext/sources/examples/common-whisper.cpp +139 -67
- data/ext/sources/examples/common-whisper.h +11 -0
- data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
- data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
- data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
- data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
- data/ext/sources/examples/server/server.cpp +213 -163
- data/ext/sources/ggml/CMakeLists.txt +29 -15
- data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
- data/ext/sources/ggml/include/ggml-alloc.h +1 -0
- data/ext/sources/ggml/include/ggml-backend.h +73 -11
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-cuda.h +3 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +8 -3
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +155 -16
- data/ext/sources/ggml/include/gguf.h +10 -2
- data/ext/sources/ggml/src/CMakeLists.txt +25 -5
- data/ext/sources/ggml/src/ggml-alloc.c +9 -10
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
- data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
- data/ext/sources/ggml/src/ggml-common.h +22 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
- data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
- data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
- data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
- data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
- data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
- data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
- data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
- data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
- data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
- data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
- data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
- data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
- data/ext/sources/ggml/src/ggml-impl.h +68 -1
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
- data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
- data/ext/sources/ggml/src/ggml-quants.c +385 -119
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
- data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
- data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
- data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
- data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
- data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
- data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
- data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
- data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
- data/ext/sources/ggml/src/ggml.c +268 -52
- data/ext/sources/ggml/src/gguf.cpp +377 -47
- data/ext/sources/include/parakeet.h +342 -0
- data/ext/sources/include/whisper.h +10 -0
- data/ext/sources/media/matmul.png +0 -0
- data/ext/sources/src/CMakeLists.txt +23 -0
- data/ext/sources/src/parakeet-arch.h +188 -0
- data/ext/sources/src/parakeet.cpp +3838 -0
- data/ext/sources/src/whisper.cpp +62 -40
- data/extsources.rb +26 -10
- data/lib/whisper/log_settable.rb +36 -0
- data/lib/whisper/model/uri.rb +13 -1
- data/lib/whisper/output.rb +74 -0
- data/sig/whisper.rbs +445 -55
- data/test/helper.rb +2 -0
- data/test/jfk_reader/jfk_reader.c +50 -7
- data/test/test_callback.rb +1 -0
- data/test/test_context_params.rb +82 -0
- data/test/test_package.rb +6 -5
- data/test/test_parakeet.rb +28 -0
- data/test/test_parakeet_callback.rb +107 -0
- data/test/test_parakeet_context.rb +116 -0
- data/test/test_parakeet_context_params.rb +24 -0
- data/test/test_parakeet_model.rb +21 -0
- data/test/test_parakeet_params.rb +78 -0
- data/test/test_parakeet_segment.rb +42 -0
- data/test/test_parakeet_token.rb +73 -0
- data/test/test_params.rb +2 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_vad_segment.rb +1 -1
- data/test/test_whisper.rb +44 -6
- data/whispercpp.gemspec +2 -2
- metadata +426 -280
- data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
- data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
- data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
- data/ext/sources/bindings/javascript/package.json +0 -26
- data/ext/sources/bindings/javascript/whisper.js +0 -19
- data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
- data/ext/sources/examples/addon.node/addon.cpp +0 -557
- data/ext/sources/examples/addon.node/index.js +0 -59
- data/ext/sources/examples/addon.node/package.json +0 -16
- data/ext/sources/examples/addon.node/vad-example.js +0 -132
- data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
- data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
- data/ext/sources/examples/coi-serviceworker.js +0 -146
- data/ext/sources/examples/command/CMakeLists.txt +0 -10
- data/ext/sources/examples/command/command.cpp +0 -802
- data/ext/sources/examples/command/commands.txt +0 -9
- data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
- data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
- data/ext/sources/examples/generate-karaoke.sh +0 -57
- data/ext/sources/examples/helpers.js +0 -191
- data/ext/sources/examples/livestream.sh +0 -112
- data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
- data/ext/sources/examples/lsp/lsp.cpp +0 -471
- data/ext/sources/examples/lsp/whisper.vim +0 -362
- data/ext/sources/examples/python/test_whisper_processor.py +0 -7
- data/ext/sources/examples/python/whisper_processor.py +0 -54
- data/ext/sources/examples/server/bench.js +0 -29
- data/ext/sources/examples/server.py +0 -120
- data/ext/sources/examples/stream/CMakeLists.txt +0 -10
- data/ext/sources/examples/stream/stream.cpp +0 -437
- data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
- data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
- data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
- data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
- data/ext/sources/examples/sycl/build.sh +0 -22
- data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
- data/ext/sources/examples/sycl/run-whisper.sh +0 -17
- data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
- data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
- data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
- data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
- data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
- data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
- data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
- data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
- data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
- data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
- data/ext/sources/examples/talk-llama/llama-context.h +0 -360
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
- data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
- data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
- data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
- data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
- data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
- data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
- data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
- data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
- data/ext/sources/examples/talk-llama/llama-io.h +0 -35
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
- data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
- data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
- data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
- data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
- data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
- data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
- data/ext/sources/examples/talk-llama/llama-model.h +0 -544
- data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
- data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
- data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
- data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
- data/ext/sources/examples/talk-llama/llama.h +0 -1540
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
- data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
- data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
- data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
- data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
- data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
- data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
- data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
- data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
- data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
- data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
- data/ext/sources/examples/talk-llama/models/models.h +0 -569
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
- data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
- data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
- data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
- data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
- data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
- data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
- data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
- data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
- data/ext/sources/examples/talk-llama/speak +0 -40
- data/ext/sources/examples/talk-llama/speak.bat +0 -1
- data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
- data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
- data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
- data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
- data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
- data/ext/sources/examples/talk-llama/unicode.h +0 -111
- data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
- data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
- data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
- data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
- data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
- data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
- data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
- data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
- data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
- data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
- data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
- data/ext/sources/tests/CMakeLists.txt +0 -112
- data/ext/sources/tests/earnings21/eval.mk +0 -58
- data/ext/sources/tests/earnings21/eval.py +0 -68
- data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
- data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
- data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
- data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
- data/ext/sources/tests/earnings21/requirements.txt +0 -6
- data/ext/sources/tests/en-0-ref.txt +0 -1
- data/ext/sources/tests/en-1-ref.txt +0 -1
- data/ext/sources/tests/en-2-ref.txt +0 -1
- data/ext/sources/tests/es-0-ref.txt +0 -1
- data/ext/sources/tests/librispeech/eval.mk +0 -39
- data/ext/sources/tests/librispeech/eval.py +0 -47
- data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
- data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
- data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
- data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
- data/ext/sources/tests/librispeech/requirements.txt +0 -6
- data/ext/sources/tests/run-tests.sh +0 -130
- data/ext/sources/tests/test-c.c +0 -3
- data/ext/sources/tests/test-vad-full.cpp +0 -56
- data/ext/sources/tests/test-vad.cpp +0 -83
- data/ext/sources/tests/test-whisper.js +0 -58
- data/lib/whisper/context.rb +0 -15
- data/lib/whisper/segment.rb +0 -58
|
@@ -1,169 +1,3293 @@
|
|
|
1
1
|
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
|
2
2
|
#define GGML_WEBGPU_SHADER_LIB_HPP
|
|
3
3
|
|
|
4
|
+
#include "ggml-impl.h"
|
|
5
|
+
#include "ggml-wgsl-shaders.hpp"
|
|
4
6
|
#include "ggml.h"
|
|
5
7
|
#include "pre_wgsl.hpp"
|
|
6
8
|
|
|
9
|
+
#include <webgpu/webgpu_cpp.h>
|
|
10
|
+
|
|
11
|
+
#include <algorithm>
|
|
12
|
+
#include <memory>
|
|
7
13
|
#include <string>
|
|
14
|
+
#include <unordered_map>
|
|
8
15
|
#include <vector>
|
|
9
16
|
|
|
10
17
|
// Byte sizes of the f16 / f32 / i32 scalar types.
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4

// Flash-attention tile, sequence and workgroup limits.
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN       20u
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE       32u
#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE      64u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u

// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u

// Upper bound on the argsort merge workgroup size.
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u

// ---- Matrix multiplication parameters ----

// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M           4
#define WEBGPU_MUL_MAT_TILE_N           4
#define WEBGPU_MUL_MAT_WG_SIZE_M        8
#define WEBGPU_MUL_MAT_WG_SIZE_N        8
#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8
#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32

// Subgroup matrix parameters
#define WEBGPU_MUL_MAT_SUBGROUP_M            2   // number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N            4   // number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M     4   // subgroup matrices each subgroup accumulates over (M)
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N     2   // subgroup matrices each subgroup accumulates over (N)
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32
#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE                 256
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG    4
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG      4

// Default workgroup size for reg-tile matrix multiplication.
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
|
60
|
+
|
|
61
|
+
// Folds std::hash<T>(value) into `seed` using the boost::hash_combine mixing formula.
// The result depends on call order, so callers must combine fields in a fixed order.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t mixed = std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= mixed;
}
|
|
65
|
+
|
|
66
|
+
// Calculates base address of a tensor ignoring the fake base pointer
|
|
67
|
+
inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) {
|
|
68
|
+
const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor;
|
|
69
|
+
return (uintptr_t) base_tensor->data + tensor->view_offs;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) {
|
|
73
|
+
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) {
|
|
77
|
+
return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) &&
|
|
78
|
+
ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a);
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
struct ggml_webgpu_shader_lib_context {
|
|
82
|
+
ggml_tensor * src0;
|
|
83
|
+
ggml_tensor * src1;
|
|
84
|
+
ggml_tensor * src2;
|
|
85
|
+
ggml_tensor * src3;
|
|
86
|
+
ggml_tensor * src4;
|
|
87
|
+
ggml_tensor * src5;
|
|
88
|
+
ggml_tensor * dst;
|
|
89
|
+
|
|
90
|
+
uint32_t max_wg_size;
|
|
91
|
+
size_t wg_mem_limit_bytes = 0;
|
|
92
|
+
bool supports_subgroups = false;
|
|
93
|
+
bool supports_subgroup_matrix = false;
|
|
94
|
+
uint32_t sg_mat_m = 0;
|
|
95
|
+
uint32_t sg_mat_n = 0;
|
|
96
|
+
uint32_t sg_mat_k = 0;
|
|
97
|
+
uint32_t min_subgroup_size = 0;
|
|
98
|
+
uint32_t max_subgroup_size = 0;
|
|
99
|
+
bool supports_dot_product = false;
|
|
100
|
+
std::string vendor;
|
|
101
|
+
};
|
|
102
|
+
|
|
103
|
+
struct webgpu_pipeline {
|
|
104
|
+
wgpu::ComputePipeline pipeline;
|
|
105
|
+
std::string name;
|
|
106
|
+
std::shared_ptr<void> context = nullptr;
|
|
107
|
+
};
|
|
108
|
+
|
|
109
|
+
// Generic shader choices: workgroup size and whether the op runs in place.
// Both default to zero / false.
struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
};
|
|
113
|
+
|
|
114
|
+
// Shader choices for binary ops: workgroup size plus three aliasing flags,
// all of which default to zero / false.
struct ggml_webgpu_binary_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
    bool     overlap{ false };
    bool     src_overlap{ false };
};
|
|
120
|
+
|
|
121
|
+
// The result of shader processing: WGSL text, a variant string, and a
// type-erased pointer to whatever decision struct produced it.
struct ggml_webgpu_processed_shader {
    std::string           wgsl;
    std::string           variant;
    std::shared_ptr<void> decisions;
};
|
|
126
|
+
|
|
127
|
+
// Shader choices for SSM conv: block size and tokens per workgroup.
// Members default to 0, matching the other *_shader_decisions structs, so a
// default-initialized instance is never read with indeterminate values.
// Positional aggregate initialization is unchanged.
struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t block_size    = 0;
    uint32_t tokens_per_wg = 0;
};
|
|
131
|
+
|
|
132
|
+
// Pipeline key for the SSM scan shader. Two keys are equal only when every field matches.
struct ggml_webgpu_ssm_scan_pipeline_key {
    int  type;
    int  d_state;
    bool xbc_overlap;

    bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & rhs) const {
        if (type != rhs.type || d_state != rhs.d_state) {
            return false;
        }
        return xbc_overlap == rhs.xbc_overlap;
    }
};
|
|
141
|
+
|
|
142
|
+
struct ggml_webgpu_ssm_scan_pipeline_key_hash {
|
|
143
|
+
size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const {
|
|
144
|
+
size_t seed = 0;
|
|
145
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
146
|
+
ggml_webgpu_hash_combine(seed, key.d_state);
|
|
147
|
+
ggml_webgpu_hash_combine(seed, key.xbc_overlap);
|
|
148
|
+
return seed;
|
|
149
|
+
}
|
|
150
|
+
};
|
|
151
|
+
|
|
152
|
+
// Shader choices for SSM scan. wg_size and tokens_per_tile were previously
// uninitialized even though xbc_overlap had a default. All members now default
// so a default-initialized instance is fully defined.
struct ggml_webgpu_ssm_scan_shader_decisions {
    uint32_t wg_size         = 0;
    uint32_t tokens_per_tile = 0;
    bool     xbc_overlap     = false;
};
|
|
157
|
+
|
|
158
|
+
/** Argsort **/
|
|
159
|
+
|
|
160
|
+
// Inputs for argsort shader selection: workgroup size limit, workgroup memory
// limit in bytes, and sort order (presumably a ggml_sort_order value — TODO confirm).
// Members default to 0, matching value-initialization, so a default-initialized
// context is never read with indeterminate values.
struct ggml_webgpu_argsort_shader_lib_context {
    uint32_t max_wg_size        = 0;
    size_t   wg_mem_limit_bytes = 0;
    int32_t  order              = 0;
};
|
|
165
|
+
|
|
166
|
+
/** Set Rows **/
|
|
167
|
+
|
|
168
|
+
// Pipeline key for the set_rows shader. The flag-like fields are stored as int.
// Equality requires all four fields to match.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type;
    int vec4;
    int i64_idx;
    int pair_blocks;

    bool operator==(const ggml_webgpu_set_rows_pipeline_key & rhs) const {
        if (dst_type != rhs.dst_type) {
            return false;
        }
        if (vec4 != rhs.vec4 || i64_idx != rhs.i64_idx) {
            return false;
        }
        return pair_blocks == rhs.pair_blocks;
    }
};
|
|
179
|
+
|
|
180
|
+
struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|
181
|
+
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
|
|
182
|
+
size_t seed = 0;
|
|
183
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
184
|
+
ggml_webgpu_hash_combine(seed, key.vec4);
|
|
185
|
+
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
|
186
|
+
ggml_webgpu_hash_combine(seed, key.pair_blocks);
|
|
187
|
+
return seed;
|
|
188
|
+
}
|
|
189
|
+
};
|
|
190
|
+
|
|
191
|
+
// Shader choices for set_rows: three feature flags plus the workgroup size.
// All members now default to false / 0, like the other decision structs, so a
// default-initialized instance carries no indeterminate values.
struct ggml_webgpu_set_rows_shader_decisions {
    bool     vec4        = false;
    bool     i64_idx     = false;
    bool     pair_blocks = false;
    uint32_t wg_size     = 0;
};
|
|
197
|
+
|
|
198
|
+
/** Set **/
|
|
199
|
+
|
|
200
|
+
struct ggml_webgpu_set_pipeline_key {
|
|
201
|
+
ggml_type type;
|
|
202
|
+
bool inplace;
|
|
203
|
+
|
|
204
|
+
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
|
|
205
|
+
return type == other.type && inplace == other.inplace;
|
|
206
|
+
}
|
|
207
|
+
};
|
|
208
|
+
|
|
209
|
+
struct ggml_webgpu_set_pipeline_key_hash {
|
|
210
|
+
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
|
|
211
|
+
size_t seed = 0;
|
|
212
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
213
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
214
|
+
return seed;
|
|
215
|
+
}
|
|
216
|
+
};
|
|
217
|
+
|
|
218
|
+
/** Get Rows **/
|
|
219
|
+
|
|
220
|
+
struct ggml_webgpu_get_rows_pipeline_key {
|
|
221
|
+
ggml_type src_type;
|
|
222
|
+
int vectorized;
|
|
223
|
+
|
|
224
|
+
bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
|
|
225
|
+
return src_type == other.src_type && vectorized == other.vectorized;
|
|
226
|
+
}
|
|
227
|
+
};
|
|
228
|
+
|
|
229
|
+
struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|
230
|
+
size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
|
|
231
|
+
size_t seed = 0;
|
|
232
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
233
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
234
|
+
return seed;
|
|
235
|
+
}
|
|
236
|
+
};
|
|
237
|
+
|
|
238
|
+
/** Row Norm **/
|
|
239
|
+
|
|
240
|
+
struct ggml_webgpu_row_norm_pipeline_key {
|
|
241
|
+
ggml_op op;
|
|
242
|
+
ggml_type src_type;
|
|
243
|
+
ggml_type dst_type;
|
|
244
|
+
bool inplace;
|
|
245
|
+
|
|
246
|
+
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
|
247
|
+
return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace;
|
|
248
|
+
}
|
|
249
|
+
};
|
|
250
|
+
|
|
251
|
+
struct ggml_webgpu_row_norm_pipeline_key_hash {
|
|
252
|
+
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
|
253
|
+
size_t seed = 0;
|
|
254
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
255
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
256
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
257
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
258
|
+
return seed;
|
|
259
|
+
}
|
|
260
|
+
};
|
|
261
|
+
|
|
262
|
+
/** RMS_NORM + MUL **/
|
|
263
|
+
|
|
264
|
+
// Pipeline key for the fused RMS_NORM + MUL shader. It consists of three aliasing flags.
struct ggml_webgpu_rms_norm_mul_pipeline_key {
    bool inplace;      // rn_src == dst
    bool overlap;      // mul_src == dst
    bool src_overlap;  // rn_src == mul_src

    bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & rhs) const {
        if (inplace != rhs.inplace) {
            return false;
        }
        if (overlap != rhs.overlap) {
            return false;
        }
        return src_overlap == rhs.src_overlap;
    }
};
|
|
273
|
+
|
|
274
|
+
struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
|
275
|
+
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
|
|
276
|
+
size_t seed = 0;
|
|
277
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
278
|
+
ggml_webgpu_hash_combine(seed, key.overlap);
|
|
279
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
280
|
+
return seed;
|
|
281
|
+
}
|
|
282
|
+
};
|
|
283
|
+
|
|
284
|
+
// Shader choices for fused RMS_NORM + MUL: workgroup size plus the aliasing
// flags, all defaulting to zero / false.
struct ggml_webgpu_rms_norm_mul_shader_decisions {
    uint32_t wg_size{ 0 };
    bool     inplace{ false };
    bool     overlap{ false };
    bool     src_overlap{ false };
};
|
|
290
|
+
|
|
291
|
+
/** Pad **/
|
|
292
|
+
// Pipeline key for the pad shader. Its only dimension is the circular flag.
struct ggml_webgpu_pad_pipeline_key {
    bool circular;

    bool operator==(const ggml_webgpu_pad_pipeline_key & rhs) const {
        return !(circular != rhs.circular);
    }
};
|
|
297
|
+
|
|
298
|
+
struct ggml_webgpu_pad_pipeline_key_hash {
|
|
299
|
+
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
|
|
300
|
+
size_t seed = 0;
|
|
301
|
+
ggml_webgpu_hash_combine(seed, key.circular);
|
|
302
|
+
return seed;
|
|
303
|
+
}
|
|
304
|
+
};
|
|
305
|
+
|
|
306
|
+
/** Solve Tri **/
|
|
307
|
+
// Pipeline key for the triangular-solve shader: type and the n/k dimensions.
struct ggml_webgpu_solve_tri_pipeline_key {
    int type;
    int n;
    int k;

    bool operator==(const ggml_webgpu_solve_tri_pipeline_key & rhs) const {
        if (type != rhs.type) {
            return false;
        }
        return n == rhs.n && k == rhs.k;
    }
};
|
|
316
|
+
|
|
317
|
+
struct ggml_webgpu_solve_tri_pipeline_key_hash {
|
|
318
|
+
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
|
|
319
|
+
size_t seed = 0;
|
|
320
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
321
|
+
ggml_webgpu_hash_combine(seed, key.n);
|
|
322
|
+
ggml_webgpu_hash_combine(seed, key.k);
|
|
323
|
+
return seed;
|
|
324
|
+
}
|
|
325
|
+
};
|
|
326
|
+
|
|
327
|
+
/** SSM Conv **/
|
|
328
|
+
// Pipeline key for the SSM conv shader: type plus a vectorization flag (stored as int).
struct ggml_webgpu_ssm_conv_pipeline_key {
    int type;
    int vectorized;

    bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & rhs) const {
        return !(type != rhs.type || vectorized != rhs.vectorized);
    }
};
|
|
336
|
+
|
|
337
|
+
/** CONV 2D */
|
|
338
|
+
struct ggml_webgpu_conv2d_pipeline_key {
|
|
339
|
+
ggml_type weight_type;
|
|
340
|
+
ggml_type input_type;
|
|
341
|
+
ggml_type output_type;
|
|
342
|
+
|
|
343
|
+
bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
|
|
344
|
+
return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
|
|
345
|
+
}
|
|
346
|
+
};
|
|
347
|
+
|
|
348
|
+
struct ggml_webgpu_conv2d_pipeline_key_hash {
|
|
349
|
+
size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
|
|
350
|
+
size_t seed = 0;
|
|
351
|
+
ggml_webgpu_hash_combine(seed, key.weight_type);
|
|
352
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
353
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
354
|
+
return seed;
|
|
355
|
+
}
|
|
356
|
+
};
|
|
357
|
+
|
|
358
|
+
/** Im2Col **/
|
|
359
|
+
struct ggml_webgpu_im2col_pipeline_key {
|
|
360
|
+
ggml_type input_type;
|
|
361
|
+
ggml_type output_type;
|
|
362
|
+
|
|
363
|
+
bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const {
|
|
364
|
+
return input_type == other.input_type && output_type == other.output_type;
|
|
365
|
+
}
|
|
366
|
+
};
|
|
367
|
+
|
|
368
|
+
struct ggml_webgpu_im2col_pipeline_key_hash {
|
|
369
|
+
size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const {
|
|
370
|
+
size_t seed = 0;
|
|
371
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
372
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
373
|
+
return seed;
|
|
374
|
+
}
|
|
375
|
+
};
|
|
376
|
+
|
|
377
|
+
/** Gated Delta Net **/
|
|
378
|
+
// Pipeline key for the gated delta net shader: type, s_v and kda.
struct ggml_webgpu_gated_delta_net_pipeline_key {
    int type;
    int s_v;
    int kda;

    bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & rhs) const {
        if (type != rhs.type || s_v != rhs.s_v) {
            return false;
        }
        return kda == rhs.kda;
    }
};
|
|
387
|
+
|
|
388
|
+
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
|
|
389
|
+
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
|
|
390
|
+
size_t seed = 0;
|
|
391
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
392
|
+
ggml_webgpu_hash_combine(seed, key.s_v);
|
|
393
|
+
ggml_webgpu_hash_combine(seed, key.kda);
|
|
394
|
+
return seed;
|
|
395
|
+
}
|
|
396
|
+
};
|
|
397
|
+
|
|
398
|
+
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
|
|
399
|
+
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
|
|
400
|
+
size_t seed = 0;
|
|
401
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
402
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
403
|
+
return seed;
|
|
404
|
+
}
|
|
405
|
+
};
|
|
406
|
+
|
|
407
|
+
/** Scale **/
|
|
408
|
+
|
|
409
|
+
// Pipeline key for the scale shader. Its only dimension is the in-place flag (stored as int).
struct ggml_webgpu_scale_pipeline_key {
    int inplace;

    bool operator==(const ggml_webgpu_scale_pipeline_key & rhs) const {
        return !(inplace != rhs.inplace);
    }
};
|
|
414
|
+
|
|
415
|
+
struct ggml_webgpu_scale_pipeline_key_hash {
|
|
416
|
+
size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
|
|
417
|
+
size_t seed = 0;
|
|
418
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
419
|
+
return seed;
|
|
420
|
+
}
|
|
421
|
+
};
|
|
422
|
+
|
|
423
|
+
/** Upscale **/
|
|
424
|
+
|
|
425
|
+
struct ggml_webgpu_upscale_pipeline_key {
|
|
426
|
+
ggml_type input_type;
|
|
427
|
+
ggml_type output_type;
|
|
428
|
+
uint32_t base_mode;
|
|
429
|
+
bool antialias;
|
|
430
|
+
|
|
431
|
+
bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const {
|
|
432
|
+
return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode &&
|
|
433
|
+
antialias == other.antialias;
|
|
434
|
+
}
|
|
435
|
+
};
|
|
436
|
+
|
|
437
|
+
struct ggml_webgpu_upscale_pipeline_key_hash {
|
|
438
|
+
size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const {
|
|
439
|
+
size_t seed = 0;
|
|
440
|
+
ggml_webgpu_hash_combine(seed, key.input_type);
|
|
441
|
+
ggml_webgpu_hash_combine(seed, key.output_type);
|
|
442
|
+
ggml_webgpu_hash_combine(seed, key.base_mode);
|
|
443
|
+
ggml_webgpu_hash_combine(seed, key.antialias);
|
|
444
|
+
return seed;
|
|
445
|
+
}
|
|
446
|
+
};
|
|
447
|
+
|
|
448
|
+
/** Concat **/
|
|
449
|
+
|
|
450
|
+
// Pipeline key for the concat shader: type plus the source-overlap flag.
struct ggml_webgpu_concat_pipeline_key {
    int  type;
    bool src_overlap;

    bool operator==(const ggml_webgpu_concat_pipeline_key & rhs) const {
        return !(type != rhs.type || src_overlap != rhs.src_overlap);
    }
};
|
|
458
|
+
|
|
459
|
+
struct ggml_webgpu_concat_pipeline_key_hash {
|
|
460
|
+
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
|
|
461
|
+
size_t seed = 0;
|
|
462
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
463
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
464
|
+
return seed;
|
|
465
|
+
}
|
|
466
|
+
};
|
|
467
|
+
|
|
468
|
+
/** Repeat **/
|
|
469
|
+
|
|
470
|
+
// Pipeline key for the repeat shader. Its only dimension is the type.
struct ggml_webgpu_repeat_pipeline_key {
    int type;

    bool operator==(const ggml_webgpu_repeat_pipeline_key & rhs) const {
        return !(type != rhs.type);
    }
};
|
|
475
|
+
|
|
476
|
+
struct ggml_webgpu_repeat_pipeline_key_hash {
|
|
477
|
+
size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
|
|
478
|
+
size_t seed = 0;
|
|
479
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
480
|
+
return seed;
|
|
481
|
+
}
|
|
482
|
+
};
|
|
483
|
+
|
|
484
|
+
/** Binary **/
|
|
485
|
+
|
|
486
|
+
// Pipeline key for binary-op shaders: type, op, and the three aliasing flags.
struct ggml_webgpu_binary_pipeline_key {
    int  type;
    int  op;
    bool inplace;
    bool overlap;
    bool src_overlap;

    bool operator==(const ggml_webgpu_binary_pipeline_key & rhs) const {
        const bool same_kind  = type == rhs.type && op == rhs.op;
        const bool same_alias = inplace == rhs.inplace && overlap == rhs.overlap && src_overlap == rhs.src_overlap;
        return same_kind && same_alias;
    }
};
|
|
498
|
+
|
|
499
|
+
struct ggml_webgpu_binary_pipeline_key_hash {
|
|
500
|
+
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
|
|
501
|
+
size_t seed = 0;
|
|
502
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
503
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
504
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
505
|
+
ggml_webgpu_hash_combine(seed, key.overlap);
|
|
506
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
507
|
+
return seed;
|
|
508
|
+
}
|
|
509
|
+
};
|
|
510
|
+
|
|
511
|
+
/* Add_Id */
|
|
512
|
+
|
|
513
|
+
// Pipeline key for the add_id shader. Its only dimension is the in-place flag.
struct ggml_webgpu_add_id_pipeline_key {
    bool inplace;

    bool operator==(const ggml_webgpu_add_id_pipeline_key & rhs) const {
        return !(inplace != rhs.inplace);
    }
};
|
|
518
|
+
|
|
519
|
+
struct ggml_webgpu_add_id_pipeline_key_hash {
|
|
520
|
+
size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
|
|
521
|
+
size_t seed = 0;
|
|
522
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
523
|
+
return seed;
|
|
524
|
+
}
|
|
525
|
+
};
|
|
526
|
+
|
|
527
|
+
/** Unary **/
|
|
528
|
+
|
|
529
|
+
struct ggml_webgpu_unary_pipeline_key {
|
|
530
|
+
int type;
|
|
531
|
+
int op;
|
|
532
|
+
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
|
533
|
+
bool inplace;
|
|
534
|
+
ggml_tri_type ttype; // only used for GGML_OP_TRI
|
|
535
|
+
|
|
536
|
+
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
|
537
|
+
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
|
|
538
|
+
ttype == other.ttype;
|
|
539
|
+
}
|
|
540
|
+
};
|
|
541
|
+
|
|
542
|
+
struct ggml_webgpu_unary_pipeline_key_hash {
|
|
543
|
+
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
|
|
544
|
+
size_t seed = 0;
|
|
545
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
546
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
547
|
+
ggml_webgpu_hash_combine(seed, key.is_unary);
|
|
548
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
549
|
+
ggml_webgpu_hash_combine(seed, key.ttype);
|
|
550
|
+
return seed;
|
|
551
|
+
}
|
|
552
|
+
};
|
|
553
|
+
|
|
554
|
+
/** FlashAttention */
|
|
555
|
+
|
|
556
|
+
struct ggml_webgpu_flash_attn_common_pipeline_key {
|
|
557
|
+
ggml_type q_type;
|
|
558
|
+
ggml_type k_type;
|
|
559
|
+
ggml_type v_type;
|
|
560
|
+
ggml_type dst_type;
|
|
19
561
|
uint32_t head_dim_qk;
|
|
20
562
|
uint32_t head_dim_v;
|
|
21
563
|
bool kv_direct;
|
|
564
|
+
bool kv_overlap;
|
|
22
565
|
bool has_mask;
|
|
23
566
|
bool has_sinks;
|
|
24
567
|
bool uses_logit_softcap;
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
568
|
+
|
|
569
|
+
bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
|
|
570
|
+
return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
|
|
571
|
+
dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
|
572
|
+
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
|
|
573
|
+
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
|
|
574
|
+
}
|
|
30
575
|
};
|
|
31
576
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
577
|
+
inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed,
|
|
578
|
+
const ggml_webgpu_flash_attn_common_pipeline_key & key) {
|
|
579
|
+
ggml_webgpu_hash_combine(seed, key.q_type);
|
|
580
|
+
ggml_webgpu_hash_combine(seed, key.k_type);
|
|
581
|
+
ggml_webgpu_hash_combine(seed, key.v_type);
|
|
582
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
583
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
|
584
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
585
|
+
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
|
586
|
+
ggml_webgpu_hash_combine(seed, key.kv_overlap);
|
|
587
|
+
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
588
|
+
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
|
589
|
+
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
struct ggml_webgpu_flash_attn_vec_pipeline_key {
|
|
593
|
+
ggml_webgpu_flash_attn_common_pipeline_key common;
|
|
594
|
+
|
|
595
|
+
bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; }
|
|
36
596
|
};
|
|
37
597
|
|
|
38
|
-
struct
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
598
|
+
struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
|
|
599
|
+
size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
|
|
600
|
+
size_t seed = 0;
|
|
601
|
+
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
|
602
|
+
return seed;
|
|
603
|
+
}
|
|
42
604
|
};
|
|
43
605
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
bool kv_direct) {
|
|
51
|
-
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
|
52
|
-
size_t f16_elems = 0;
|
|
53
|
-
size_t f32_elems = 0;
|
|
54
|
-
f16_elems += q_tile * head_dim_qk; // q_shmem
|
|
55
|
-
if (!kv_direct) {
|
|
56
|
-
f16_elems += kv_tile * max_head_dim; // kv_shmem
|
|
606
|
+
struct ggml_webgpu_flash_attn_pipeline_key {
|
|
607
|
+
ggml_webgpu_flash_attn_common_pipeline_key common;
|
|
608
|
+
bool use_sg_matrix;
|
|
609
|
+
|
|
610
|
+
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
|
611
|
+
return common == other.common && use_sg_matrix == other.use_sg_matrix;
|
|
57
612
|
}
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
613
|
+
};
|
|
614
|
+
|
|
615
|
+
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|
616
|
+
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
|
617
|
+
size_t seed = 0;
|
|
618
|
+
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
|
619
|
+
ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
|
|
620
|
+
return seed;
|
|
61
621
|
}
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
622
|
+
};
|
|
623
|
+
|
|
624
|
+
// Tiling choices for the vector flash-attention path. Both default to 0.
struct ggml_webgpu_flash_attn_vec_decisions {
    uint32_t kv_tile{ 0 };
    uint32_t wg_size{ 0 };
};
|
|
628
|
+
|
|
629
|
+
// Tiling choices for the general flash-attention path: subgroup-matrix use,
// Q/KV tile sizes and workgroup size. All default to false / 0.
struct ggml_webgpu_flash_attn_decisions {
    bool     use_sg_matrix{ false };
    uint32_t q_tile{ 0 };
    uint32_t kv_tile{ 0 };
    uint32_t wg_size{ 0 };
};
|
|
635
|
+
|
|
636
|
+
// Vector width for K/V in the tiled flash-attention path. It is also the
// element alignment checked by ggml_webgpu_flash_attn_float_vec4_aligned.
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH{ 4u };
// Q tile size for the tiled flash-attention path.
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE{ 4u };
|
|
638
|
+
|
|
639
|
+
inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
|
|
640
|
+
constexpr uintptr_t ptr_base_addr = 0x1000u;
|
|
641
|
+
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
|
|
642
|
+
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
|
|
66
643
|
}
|
|
67
644
|
|
|
68
|
-
|
|
69
|
-
const
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
645
|
+
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
|
|
646
|
+
const uint32_t offset_elems =
|
|
647
|
+
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
|
|
648
|
+
ggml_type_size(K->type));
|
|
649
|
+
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
|
|
653
|
+
const ggml_tensor * V,
|
|
654
|
+
size_t storage_offset_alignment) {
|
|
655
|
+
return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) &&
|
|
656
|
+
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
|
|
660
|
+
const ggml_tensor * K,
|
|
661
|
+
const ggml_tensor * V,
|
|
662
|
+
uint32_t kv_direct_align) {
|
|
663
|
+
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
|
|
664
|
+
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
// Builds the common flash-attention key from a shader-library context.
// Inputs read: src0 = Q, src1 = K, src2 = V, dst = output. Mask and sinks
// count as present when src3 / src4 are non-null, and the logit softcap counts
// as used when dst's f32 op param at index 2 is non-zero.
inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
    const ggml_webgpu_shader_lib_context & context,
    uint32_t                               kv_direct_align) {
    const ggml_tensor * q = context.src0;
    const ggml_tensor * k = context.src1;
    const ggml_tensor * v = context.src2;

    ggml_webgpu_flash_attn_common_pipeline_key result = {};

    // Element types.
    result.q_type   = q->type;
    result.k_type   = k->type;
    result.v_type   = v->type;
    result.dst_type = context.dst->type;

    // Head dimensions.
    result.head_dim_qk = static_cast<uint32_t>(q->ne[0]);
    result.head_dim_v  = static_cast<uint32_t>(v->ne[0]);

    // Feature flags.
    result.kv_direct          = ggml_webgpu_flash_attn_kv_direct(q, k, v, kv_direct_align);
    result.kv_overlap         = ggml_webgpu_tensor_overlap(k, v);
    result.has_mask           = context.src3 != nullptr;
    result.has_sinks          = context.src4 != nullptr;
    result.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;

    return result;
}
|
|
85
684
|
|
|
86
|
-
inline
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
685
|
+
inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
|
|
686
|
+
const ggml_webgpu_flash_attn_common_pipeline_key & key,
|
|
687
|
+
std::string & variant,
|
|
688
|
+
uint32_t q_tile,
|
|
689
|
+
uint32_t kv_tile,
|
|
690
|
+
uint32_t wg_size) {
|
|
90
691
|
std::vector<std::string> defines;
|
|
91
|
-
std::string variant = "flash_attn";
|
|
92
692
|
|
|
93
|
-
switch (
|
|
693
|
+
switch (key.k_type) {
|
|
694
|
+
case GGML_TYPE_F32:
|
|
695
|
+
defines.push_back("K_F32");
|
|
696
|
+
break;
|
|
697
|
+
case GGML_TYPE_F16:
|
|
698
|
+
defines.push_back("K_F16");
|
|
699
|
+
break;
|
|
700
|
+
case GGML_TYPE_Q4_0:
|
|
701
|
+
defines.push_back("K_Q4_0");
|
|
702
|
+
break;
|
|
703
|
+
case GGML_TYPE_Q8_0:
|
|
704
|
+
defines.push_back("K_Q8_0");
|
|
705
|
+
break;
|
|
706
|
+
default:
|
|
707
|
+
GGML_ABORT("Unsupported K type for flash attention shader");
|
|
708
|
+
}
|
|
709
|
+
variant += std::string("_k") + ggml_type_name(key.k_type);
|
|
710
|
+
|
|
711
|
+
switch (key.v_type) {
|
|
94
712
|
case GGML_TYPE_F32:
|
|
95
|
-
defines.push_back("
|
|
713
|
+
defines.push_back("V_F32");
|
|
96
714
|
break;
|
|
97
715
|
case GGML_TYPE_F16:
|
|
98
|
-
defines.push_back("
|
|
716
|
+
defines.push_back("V_F16");
|
|
99
717
|
break;
|
|
100
718
|
case GGML_TYPE_Q4_0:
|
|
101
|
-
defines.push_back("
|
|
719
|
+
defines.push_back("V_Q4_0");
|
|
102
720
|
break;
|
|
103
721
|
case GGML_TYPE_Q8_0:
|
|
104
|
-
defines.push_back("
|
|
722
|
+
defines.push_back("V_Q8_0");
|
|
723
|
+
break;
|
|
724
|
+
default:
|
|
725
|
+
GGML_ABORT("Unsupported V type for flash attention shader");
|
|
726
|
+
}
|
|
727
|
+
variant += std::string("_v") + ggml_type_name(key.v_type);
|
|
728
|
+
|
|
729
|
+
switch (key.q_type) {
|
|
730
|
+
case GGML_TYPE_F32:
|
|
731
|
+
defines.push_back("Q_F32");
|
|
732
|
+
break;
|
|
733
|
+
case GGML_TYPE_F16:
|
|
734
|
+
defines.push_back("Q_F16");
|
|
735
|
+
break;
|
|
736
|
+
default:
|
|
737
|
+
GGML_ABORT("Unsupported Q type for flash attention shader");
|
|
738
|
+
}
|
|
739
|
+
variant += std::string("_q") + ggml_type_name(key.q_type);
|
|
740
|
+
|
|
741
|
+
switch (key.dst_type) {
|
|
742
|
+
case GGML_TYPE_F32:
|
|
743
|
+
defines.push_back("DST_F32");
|
|
744
|
+
break;
|
|
745
|
+
case GGML_TYPE_F16:
|
|
746
|
+
defines.push_back("DST_F16");
|
|
105
747
|
break;
|
|
106
748
|
default:
|
|
107
|
-
GGML_ABORT("Unsupported
|
|
749
|
+
GGML_ABORT("Unsupported dst type for flash attention shader");
|
|
108
750
|
}
|
|
109
|
-
variant += std::string("
|
|
751
|
+
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
|
110
752
|
|
|
111
|
-
if (
|
|
753
|
+
if (key.has_mask) {
|
|
112
754
|
defines.push_back("MASK");
|
|
113
755
|
variant += "_mask";
|
|
114
756
|
}
|
|
115
|
-
if (
|
|
757
|
+
if (key.has_sinks) {
|
|
116
758
|
defines.push_back("SINKS");
|
|
117
759
|
variant += "_sinks";
|
|
118
760
|
}
|
|
119
|
-
if (
|
|
761
|
+
if (key.uses_logit_softcap) {
|
|
120
762
|
defines.push_back("LOGIT_SOFTCAP");
|
|
121
763
|
variant += "_lgsc";
|
|
122
764
|
}
|
|
123
|
-
|
|
124
|
-
if (context.kv_direct) {
|
|
765
|
+
if (key.kv_direct) {
|
|
125
766
|
defines.push_back("KV_DIRECT");
|
|
126
767
|
variant += "_kvdirect";
|
|
127
768
|
}
|
|
769
|
+
if (key.kv_overlap) {
|
|
770
|
+
defines.push_back("KV_OVERLAP");
|
|
771
|
+
variant += "_kv_overlap";
|
|
772
|
+
}
|
|
128
773
|
|
|
129
|
-
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(
|
|
130
|
-
variant += std::string("_hsqk") + std::to_string(
|
|
774
|
+
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
|
775
|
+
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
|
131
776
|
|
|
132
|
-
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(
|
|
133
|
-
variant += std::string("_hsv") + std::to_string(
|
|
777
|
+
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
|
778
|
+
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
|
134
779
|
|
|
135
|
-
|
|
136
|
-
defines.push_back(std::string("
|
|
137
|
-
defines.push_back(std::string("
|
|
138
|
-
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
|
780
|
+
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
|
781
|
+
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
|
782
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
139
783
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
784
|
+
if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
|
|
785
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
return defines;
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
|
792
|
+
uint32_t head_dim_v;
|
|
793
|
+
uint32_t wg_size;
|
|
794
|
+
ggml_type dst_type;
|
|
795
|
+
};
|
|
796
|
+
|
|
797
|
+
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
|
|
798
|
+
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
|
|
799
|
+
size_t seed = 0;
|
|
800
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
801
|
+
ggml_webgpu_hash_combine(seed, key.wg_size);
|
|
802
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
803
|
+
return seed;
|
|
804
|
+
}
|
|
805
|
+
};
|
|
806
|
+
|
|
807
|
+
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
|
|
808
|
+
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
|
|
809
|
+
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type;
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
// Cache key for the flash-attention blk pipelines (see flash_attn_blk_pipelines).
// The KV tile size is the only field that distinguishes cached pipelines.
struct ggml_webgpu_flash_attn_blk_pipeline_key {
    uint32_t kv_tile;

    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & rhs) const { return rhs.kv_tile == kv_tile; }
};
|
|
817
|
+
|
|
818
|
+
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
|
|
819
|
+
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
|
|
820
|
+
size_t seed = 0;
|
|
821
|
+
ggml_webgpu_hash_combine(seed, key.kv_tile);
|
|
822
|
+
return seed;
|
|
823
|
+
}
|
|
824
|
+
};
|
|
825
|
+
|
|
826
|
+
// Note: this will slightly overestimate memory usage for vec path
|
|
827
|
+
// since row_max and exp_sum shmem are not needed.
|
|
828
|
+
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
829
|
+
uint32_t kv_tile,
|
|
830
|
+
uint32_t head_dim_qk,
|
|
831
|
+
uint32_t head_dim_v,
|
|
832
|
+
bool has_mask,
|
|
833
|
+
bool kv_direct) {
|
|
834
|
+
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
|
835
|
+
size_t f16_elems = 0;
|
|
836
|
+
size_t f32_elems = 0;
|
|
837
|
+
|
|
838
|
+
f32_elems += q_tile * head_dim_qk; // q_shmem
|
|
839
|
+
if (!kv_direct) {
|
|
840
|
+
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
|
841
|
+
}
|
|
842
|
+
f32_elems += q_tile * head_dim_v; // o_shmem
|
|
843
|
+
if (has_mask) {
|
|
844
|
+
f32_elems += q_tile * kv_tile; // mask_shmem
|
|
845
|
+
}
|
|
846
|
+
f32_elems += q_tile * kv_tile; // inter_shmem
|
|
847
|
+
f32_elems += q_tile; // row_max_shmem
|
|
848
|
+
f32_elems += q_tile; // exp_sum_shmem
|
|
849
|
+
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
850
|
+
}
|
|
851
|
+
|
|
852
|
+
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
|
|
853
|
+
uint32_t q_tile,
|
|
854
|
+
uint32_t kv_granularity,
|
|
855
|
+
uint32_t head_dim_qk,
|
|
856
|
+
uint32_t head_dim_v,
|
|
857
|
+
bool has_mask,
|
|
858
|
+
bool kv_direct) {
|
|
859
|
+
const size_t base_q_bytes =
|
|
860
|
+
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
861
|
+
if (limit_bytes <= base_q_bytes) {
|
|
862
|
+
return 0;
|
|
863
|
+
}
|
|
864
|
+
const size_t one_kv_bytes =
|
|
865
|
+
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
866
|
+
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
|
|
867
|
+
if (bytes_per_kv == 0) {
|
|
868
|
+
return 0;
|
|
869
|
+
}
|
|
870
|
+
const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
|
871
|
+
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
|
|
872
|
+
}
|
|
873
|
+
|
|
874
|
+
inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
|
|
875
|
+
uint32_t head_dim_qk,
|
|
876
|
+
uint32_t head_dim_v,
|
|
877
|
+
bool has_mask,
|
|
878
|
+
bool kv_direct) {
|
|
879
|
+
const uint32_t max_kv_tile =
|
|
880
|
+
ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
|
881
|
+
GGML_ASSERT(max_kv_tile > 0);
|
|
882
|
+
|
|
883
|
+
uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
|
|
884
|
+
if (kv_direct) {
|
|
885
|
+
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
|
886
|
+
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
|
887
|
+
kv_tile -= 1u;
|
|
149
888
|
}
|
|
150
889
|
}
|
|
151
890
|
|
|
152
|
-
|
|
153
|
-
|
|
891
|
+
return kv_tile;
|
|
892
|
+
}
|
|
154
893
|
|
|
155
|
-
|
|
156
|
-
|
|
894
|
+
inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix,
|
|
895
|
+
uint32_t sg_mat_k,
|
|
896
|
+
uint32_t sg_mat_n,
|
|
897
|
+
const ggml_tensor * Q,
|
|
898
|
+
const ggml_tensor * V) {
|
|
899
|
+
return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
|
|
900
|
+
}
|
|
157
901
|
|
|
158
|
-
|
|
902
|
+
/** Matrix Multiplication **/
|
|
903
|
+
|
|
904
|
+
struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
|
905
|
+
ggml_type src0_type;
|
|
906
|
+
ggml_type src1_type;
|
|
907
|
+
int vectorized;
|
|
908
|
+
bool use_mmvq;
|
|
909
|
+
|
|
910
|
+
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
|
911
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
|
912
|
+
use_mmvq == other.use_mmvq;
|
|
913
|
+
}
|
|
914
|
+
};
|
|
915
|
+
|
|
916
|
+
struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
|
917
|
+
size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
|
|
918
|
+
size_t seed = 0;
|
|
919
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
920
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
921
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
922
|
+
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
|
923
|
+
return seed;
|
|
924
|
+
}
|
|
925
|
+
};
|
|
926
|
+
|
|
927
|
+
// Launch parameters chosen for a matrix-vector pipeline.
//
// Every member defaults to zero, so a default-initialized instance never carries indeterminate
// values. The struct remains an aggregate (C++14+), so brace initialization still works.
struct ggml_webgpu_mul_mat_vec_shader_decisions {
    uint32_t wg_size        = 0;
    uint32_t outputs_per_wg = 0;
    uint32_t vec_size       = 0;
};
|
|
932
|
+
|
|
933
|
+
struct ggml_webgpu_quantize_q8_pipeline_key {
|
|
934
|
+
ggml_type src0_type;
|
|
935
|
+
|
|
936
|
+
bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; }
|
|
937
|
+
};
|
|
938
|
+
|
|
939
|
+
struct ggml_webgpu_quantize_q8_pipeline_key_hash {
|
|
940
|
+
size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const {
|
|
941
|
+
size_t seed = 0;
|
|
942
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
943
|
+
return seed;
|
|
944
|
+
}
|
|
945
|
+
};
|
|
946
|
+
|
|
947
|
+
struct ggml_webgpu_mul_mat_pipeline_key {
|
|
948
|
+
ggml_type src0_type;
|
|
949
|
+
ggml_type src1_type;
|
|
950
|
+
int vectorized;
|
|
951
|
+
int use_subgroup_matrix;
|
|
159
952
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
953
|
+
bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
|
|
954
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
|
955
|
+
use_subgroup_matrix == other.use_subgroup_matrix;
|
|
956
|
+
}
|
|
957
|
+
};
|
|
958
|
+
|
|
959
|
+
struct ggml_webgpu_mul_mat_pipeline_key_hash {
|
|
960
|
+
size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
|
|
961
|
+
size_t seed = 0;
|
|
962
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
963
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
964
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
965
|
+
ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
|
|
966
|
+
return seed;
|
|
967
|
+
}
|
|
968
|
+
};
|
|
969
|
+
|
|
970
|
+
// Tiling and workgroup parameters chosen for a matrix-matrix pipeline.
//
// Every member defaults to zero, so a default-initialized instance never carries indeterminate
// values. Member order is unchanged and the struct remains an aggregate, so existing brace
// initialization keeps working.
struct ggml_webgpu_mul_mat_shader_decisions {
    uint32_t tile_k              = 0;
    uint32_t wg_size_m           = 0;
    uint32_t wg_size_n           = 0;
    uint32_t wg_size             = 0;
    uint32_t outputs_per_wg      = 0;
    int      use_subgroup_matrix = 0;

    uint32_t tile_m = 0;
    uint32_t tile_n = 0;

    // Subgroup matrix parameters
    uint32_t subgroup_m        = 0;
    uint32_t subgroup_n        = 0;
    uint32_t subgroup_matrix_m = 0;
    uint32_t subgroup_matrix_n = 0;

    uint32_t mul_mat_wg_size = 0;
};
|
|
989
|
+
|
|
990
|
+
/** MUL_MAT_ID **/
|
|
991
|
+
|
|
992
|
+
struct ggml_webgpu_mul_mat_id_pipeline_key {
|
|
993
|
+
ggml_type src0_type;
|
|
994
|
+
ggml_type src1_type;
|
|
995
|
+
uint32_t n_experts;
|
|
996
|
+
int vectorized;
|
|
997
|
+
|
|
998
|
+
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
|
999
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
|
1000
|
+
vectorized == other.vectorized;
|
|
1001
|
+
}
|
|
1002
|
+
};
|
|
1003
|
+
|
|
1004
|
+
struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
|
1005
|
+
size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
|
|
1006
|
+
size_t seed = 0;
|
|
1007
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
1008
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
1009
|
+
ggml_webgpu_hash_combine(seed, key.n_experts);
|
|
1010
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
1011
|
+
return seed;
|
|
1012
|
+
}
|
|
1013
|
+
};
|
|
1014
|
+
|
|
1015
|
+
/** Cpy **/
|
|
1016
|
+
|
|
1017
|
+
struct ggml_webgpu_cpy_pipeline_key {
|
|
1018
|
+
ggml_type src_type;
|
|
1019
|
+
ggml_type dst_type;
|
|
1020
|
+
|
|
1021
|
+
bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
|
|
1022
|
+
return src_type == other.src_type && dst_type == other.dst_type;
|
|
1023
|
+
}
|
|
1024
|
+
};
|
|
1025
|
+
|
|
1026
|
+
struct ggml_webgpu_cpy_pipeline_key_hash {
|
|
1027
|
+
size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
|
|
1028
|
+
size_t seed = 0;
|
|
1029
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
1030
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
1031
|
+
return seed;
|
|
1032
|
+
}
|
|
1033
|
+
};
|
|
1034
|
+
|
|
1035
|
+
/** Glu **/
|
|
1036
|
+
|
|
1037
|
+
struct ggml_webgpu_glu_pipeline_key {
|
|
1038
|
+
ggml_glu_op glu_op;
|
|
1039
|
+
ggml_type type;
|
|
1040
|
+
bool split;
|
|
1041
|
+
|
|
1042
|
+
bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
|
|
1043
|
+
return glu_op == other.glu_op && type == other.type && split == other.split;
|
|
1044
|
+
}
|
|
1045
|
+
};
|
|
1046
|
+
|
|
1047
|
+
struct ggml_webgpu_glu_pipeline_key_hash {
|
|
1048
|
+
size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
|
|
1049
|
+
size_t seed = 0;
|
|
1050
|
+
ggml_webgpu_hash_combine(seed, key.glu_op);
|
|
1051
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
1052
|
+
ggml_webgpu_hash_combine(seed, key.split);
|
|
1053
|
+
return seed;
|
|
1054
|
+
}
|
|
1055
|
+
};
|
|
1056
|
+
|
|
1057
|
+
/** Rope **/
|
|
1058
|
+
|
|
1059
|
+
struct ggml_webgpu_rope_pipeline_key {
|
|
1060
|
+
ggml_type type;
|
|
1061
|
+
bool inplace;
|
|
1062
|
+
bool has_ff;
|
|
1063
|
+
|
|
1064
|
+
bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
|
|
1065
|
+
return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
|
|
1066
|
+
}
|
|
1067
|
+
};
|
|
1068
|
+
|
|
1069
|
+
struct ggml_webgpu_rope_pipeline_key_hash {
|
|
1070
|
+
size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
|
|
1071
|
+
size_t seed = 0;
|
|
1072
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
1073
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
1074
|
+
ggml_webgpu_hash_combine(seed, key.has_ff);
|
|
1075
|
+
return seed;
|
|
1076
|
+
}
|
|
1077
|
+
};
|
|
1078
|
+
|
|
1079
|
+
/** SoftMax **/
|
|
1080
|
+
|
|
1081
|
+
struct ggml_webgpu_soft_max_pipeline_key {
|
|
1082
|
+
ggml_type mask_type;
|
|
1083
|
+
bool has_mask;
|
|
1084
|
+
bool has_sink;
|
|
1085
|
+
bool inplace;
|
|
1086
|
+
|
|
1087
|
+
bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
|
|
1088
|
+
return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
|
|
1089
|
+
inplace == other.inplace;
|
|
1090
|
+
}
|
|
1091
|
+
};
|
|
1092
|
+
|
|
1093
|
+
struct ggml_webgpu_soft_max_pipeline_key_hash {
|
|
1094
|
+
size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
|
|
1095
|
+
size_t seed = 0;
|
|
1096
|
+
ggml_webgpu_hash_combine(seed, key.mask_type);
|
|
1097
|
+
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
1098
|
+
ggml_webgpu_hash_combine(seed, key.has_sink);
|
|
1099
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
1100
|
+
return seed;
|
|
1101
|
+
}
|
|
1102
|
+
};
|
|
1103
|
+
|
|
1104
|
+
/** MMVQ **/
|
|
1105
|
+
|
|
1106
|
+
inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
|
1107
|
+
const ggml_tensor * src1,
|
|
1108
|
+
bool supports_dot_product,
|
|
1109
|
+
const std::string & vendor) {
|
|
1110
|
+
if (src1->ne[1] == 1) {
|
|
1111
|
+
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
|
1112
|
+
if (supports_dp4a && supports_dot_product) {
|
|
1113
|
+
switch (src1->type) {
|
|
1114
|
+
case GGML_TYPE_F32:
|
|
1115
|
+
switch (src0->type) {
|
|
1116
|
+
case GGML_TYPE_Q4_0:
|
|
1117
|
+
case GGML_TYPE_Q4_1:
|
|
1118
|
+
case GGML_TYPE_Q8_0:
|
|
1119
|
+
case GGML_TYPE_Q2_K:
|
|
1120
|
+
case GGML_TYPE_Q4_K:
|
|
1121
|
+
return src0->ne[0] % 4 == 0;
|
|
1122
|
+
default:
|
|
1123
|
+
break;
|
|
1124
|
+
}
|
|
1125
|
+
break;
|
|
1126
|
+
default:
|
|
1127
|
+
break;
|
|
1128
|
+
}
|
|
1129
|
+
}
|
|
1130
|
+
}
|
|
1131
|
+
return false;
|
|
167
1132
|
}
|
|
168
1133
|
|
|
1134
|
+
class ggml_webgpu_shader_lib {
|
|
1135
|
+
wgpu::Device device;
|
|
1136
|
+
pre_wgsl::Preprocessor preprocessor;
|
|
1137
|
+
|
|
1138
|
+
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
|
1139
|
+
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
|
|
1140
|
+
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
|
|
1141
|
+
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
|
|
1142
|
+
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
|
1143
|
+
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
|
1144
|
+
row_norm_pipelines; // op/inplace
|
|
1145
|
+
|
|
1146
|
+
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
|
1147
|
+
get_rows_pipelines; // src_type, vectorized
|
|
1148
|
+
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
|
1149
|
+
unary_pipelines; // type/op/inplace
|
|
1150
|
+
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
|
1151
|
+
scale_pipelines; // inplace
|
|
1152
|
+
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
|
1153
|
+
solve_tri_pipelines; // type
|
|
1154
|
+
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
|
1155
|
+
ssm_conv_pipelines; // type/vectorized
|
|
1156
|
+
std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash>
|
|
1157
|
+
ssm_scan_pipelines; // type/d_state
|
|
1158
|
+
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
|
1159
|
+
webgpu_pipeline,
|
|
1160
|
+
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
|
1161
|
+
gated_delta_net_pipelines; // type/S_v/kda
|
|
1162
|
+
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
|
1163
|
+
pad_pipelines; // circular/non-circular
|
|
1164
|
+
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
|
1165
|
+
binary_pipelines; // type/op/inplace/overlap/src_overlap
|
|
1166
|
+
std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
|
|
1167
|
+
add_id_pipelines; // inplace
|
|
1168
|
+
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
|
1169
|
+
concat_pipelines; // type
|
|
1170
|
+
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
|
1171
|
+
repeat_pipelines; // type
|
|
1172
|
+
std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
|
|
1173
|
+
webgpu_pipeline,
|
|
1174
|
+
ggml_webgpu_flash_attn_vec_pipeline_key_hash>
|
|
1175
|
+
flash_attn_vec_pipelines;
|
|
1176
|
+
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
|
1177
|
+
flash_attn_pipelines;
|
|
1178
|
+
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
|
1179
|
+
webgpu_pipeline,
|
|
1180
|
+
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
|
|
1181
|
+
flash_attn_vec_reduce_pipelines;
|
|
1182
|
+
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
|
|
1183
|
+
webgpu_pipeline,
|
|
1184
|
+
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
|
|
1185
|
+
flash_attn_blk_pipelines;
|
|
1186
|
+
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
|
1187
|
+
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
|
1188
|
+
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
|
1189
|
+
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
|
1190
|
+
std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash>
|
|
1191
|
+
quantize_q8_pipelines;
|
|
1192
|
+
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
|
1193
|
+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
|
1194
|
+
mul_mat_id_pipelines; // src0_type/src1_type
|
|
1195
|
+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
|
1196
|
+
mul_mat_id_vec_pipelines; // src0_type/src1_type
|
|
1197
|
+
|
|
1198
|
+
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
|
1199
|
+
set_rows_pipelines;
|
|
1200
|
+
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
|
|
1201
|
+
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
|
|
1202
|
+
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
|
|
1203
|
+
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
|
|
1204
|
+
rope_pipelines;
|
|
1205
|
+
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
|
|
1206
|
+
soft_max_pipelines;
|
|
1207
|
+
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
|
|
1208
|
+
conv2d_pipelines;
|
|
1209
|
+
std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash>
|
|
1210
|
+
im2col_pipelines;
|
|
1211
|
+
|
|
1212
|
+
std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key,
|
|
1213
|
+
webgpu_pipeline,
|
|
1214
|
+
ggml_webgpu_rms_norm_mul_pipeline_key_hash>
|
|
1215
|
+
rms_norm_mul_pipelines;
|
|
1216
|
+
std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash>
|
|
1217
|
+
upscale_pipelines;
|
|
1218
|
+
|
|
1219
|
+
public:
|
|
1220
|
+
// Stores the device that every cached pipeline is created on.
ggml_webgpu_shader_lib(wgpu::Device device) : device(device) {}
|
|
1221
|
+
|
|
1222
|
+
// Returns the SUM_ROWS pipeline, compiling it on first use with WG_SIZE = context.max_wg_size.
// There are no variants yet, so one fixed cache slot (key 1) is used.
webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
    constexpr int cache_key = 1;
    if (auto cached = sum_rows_pipelines.find(cache_key); cached != sum_rows_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines = { std::string("WG_SIZE=") + std::to_string(context.max_wg_size) };
    auto                     processed = preprocessor.preprocess(wgsl_sum_rows, defines);

    webgpu_pipeline pipeline      = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
    sum_rows_pipelines[cache_key] = pipeline;
    return pipeline;
}
|
|
1234
|
+
|
|
1235
|
+
// Returns the RMS_NORM / NORM / L2_NORM pipeline for `context`, compiling and caching it on first use.
//
// Variants are keyed by op, in-place-ness (src0 equal to dst), and source/destination type.
// SRC_* and DST_* defines are emitted only for F32/F16; other types get no type define.
// The workgroup size is min(context.max_wg_size, 128). Aborts for any other op.
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_row_norm_pipeline_key key = {};
    key.op                                = context.dst->op;
    key.src_type                          = context.src0->type;
    key.dst_type                          = context.dst->type;
    key.inplace                           = ggml_webgpu_tensor_equal(context.src0, context.dst);

    if (auto cached = row_norm_pipelines.find(key); cached != row_norm_pipelines.end()) {
        return cached->second;
    }

    std::string              variant;
    std::vector<std::string> defines;

    switch (key.op) {
        case GGML_OP_RMS_NORM:
            variant = "rms_norm";
            defines.push_back("RMS_NORM");
            break;
        case GGML_OP_NORM:
            variant = "norm";
            defines.push_back("NORM");
            break;
        case GGML_OP_L2_NORM:
            variant = "l2_norm";
            defines.push_back("L2_NORM");
            break;
        default:
            GGML_ABORT("Unsupported op for row_norm shader");
    }

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    switch (key.src_type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC_F32");
            variant += "_src_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC_F16");
            variant += "_src_f16";
            break;
        default:
            break;
    }

    switch (key.dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_dst_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_dst_f16";
            break;
        default:
            break;
    }

    constexpr uint32_t preferred_wg_size = 128u;
    const uint32_t     wg_size           = std::min(context.max_wg_size, preferred_wg_size);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    auto processed = preprocessor.preprocess(wgsl_row_norm, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    row_norm_pipelines[key]  = pipeline;
    return pipeline;
}
|
|
1298
|
+
|
|
1299
|
+
// Returns the ARGMAX pipeline, compiling and caching it on first use.
// The VEC4 variant is selected when src0's row length (ne[0]) is a multiple of 4.
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const bool use_vec4  = context.src0->ne[0] % 4 == 0;
    const int  cache_key = use_vec4 ? 1 : 0;

    if (auto cached = argmax_pipelines.find(cache_key); cached != argmax_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines = { std::string("WG_SIZE=") + std::to_string(context.max_wg_size) };
    std::string              variant = "argmax";
    if (use_vec4) {
        defines.push_back("VEC4");
        variant += "_vec4";
    }

    auto processed              = preprocessor.preprocess(wgsl_argmax, defines);
    argmax_pipelines[cache_key] = ggml_webgpu_create_pipeline(device, processed, variant);
    return argmax_pipelines.at(cache_key);
}
|
|
1318
|
+
|
|
1319
|
+
// Returns the SET_ROWS pipeline for `context`, compiling and caching it on first use.
//
// The key records:
//   * the destination type;
//   * vec4: dst is F32/F16 and src0->ne[0] is a multiple of 4;
//   * i64_idx: src1 (the row indices) is I64;
//   * pair_blocks: dst is quantized and src0->ne[0] spans an even number of blocks.
// Quantized destinations are built from wgsl_set_rows_quant; all others use wgsl_set_rows.
// Aborts for destination types other than F32, F16, Q8_0 and Q4_0.
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const auto dst_type  = context.dst->type;
    const bool quantized = ggml_is_quantized(dst_type);
    const bool float_dst = dst_type == GGML_TYPE_F32 || dst_type == GGML_TYPE_F16;
    const auto row_len   = context.src0->ne[0];

    ggml_webgpu_set_rows_pipeline_key key = {};
    key.dst_type                          = dst_type;
    key.vec4                              = float_dst && row_len % 4 == 0;
    key.i64_idx                           = context.src1->type == GGML_TYPE_I64;
    key.pair_blocks                       = quantized && (row_len / ggml_blck_size(dst_type)) % 2 == 0;

    if (auto cached = set_rows_pipelines.find(key); cached != set_rows_pipelines.end()) {
        return cached->second;
    }

    std::string              variant = "set_rows";
    std::vector<std::string> defines;

    switch (dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_dstf32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_dstf16";
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("DST_Q8_0");
            variant += "_dstq8_0";
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("DST_Q4_0");
            variant += "_dstq4_0";
            break;
        default:
            GGML_ABORT("Unsupported dst type for set_rows shader");
    }

    if (key.vec4) {
        defines.push_back("VEC4");
        variant += "_vec4";
    }
    if (key.i64_idx) {
        defines.push_back("I64_IDX");
        variant += "_i64idx";
    }
    if (key.pair_blocks) {
        defines.push_back("PAIR_BLOCKS");
        variant += "_pair_blocks";
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows;
    auto         processed     = preprocessor.preprocess(shader_source, defines);

    auto decisions         = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
    decisions->vec4        = key.vec4;
    decisions->i64_idx     = key.i64_idx;
    decisions->pair_blocks = key.pair_blocks;
    decisions->wg_size     = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    set_rows_pipelines[key]  = pipeline;
    return pipeline;
}
|
|
1383
|
+
|
|
1384
|
+
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached SET pipeline, keyed on the dst element type and on whether dst is
    // the same tensor as src0 (in-place variant). Builds and caches it on first use.
    ggml_webgpu_set_pipeline_key key = {};
    key.type    = context.dst->type;
    key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

    if (auto hit = set_pipelines.find(key); hit != set_pipelines.end()) {
        return hit->second;
    }

    std::vector<std::string> defines;
    std::string              variant("set");

    // Only F32 and I32 element types have a SET shader.
    if (key.type == GGML_TYPE_F32) {
        defines.emplace_back("TYPE_F32");
        variant.append("_f32");
    } else if (key.type == GGML_TYPE_I32) {
        defines.emplace_back("TYPE_I32");
        variant.append("_i32");
    } else {
        GGML_ABORT("Unsupported type for set shader");
    }

    if (key.inplace) {
        defines.emplace_back("INPLACE");
        variant.append("_inplace");
    }

    defines.emplace_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_set, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    set_pipelines[key]       = pipeline;
    return pipeline;
}
|
|
1426
|
+
|
|
1427
|
+
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // CUMSUM has a single shader variant, cached under the fixed key 1.
    // Its only define is WG_SIZE = context.max_wg_size; no decisions context is attached.
    constexpr int k_cumsum_key = 1;

    if (cumsum_pipelines.find(k_cumsum_key) == cumsum_pipelines.end()) {
        std::vector<std::string> defines{ "WG_SIZE=" + std::to_string(context.max_wg_size) };
        auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
        cumsum_pipelines[k_cumsum_key] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
    }
    return cumsum_pipelines[k_cumsum_key];
}
|
|
1440
|
+
|
|
1441
|
+
webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached ARGSORT (or TOP_K) pipeline for a sort direction.
    // Direction: 0 = ascending, 1 = descending. TOP_K always uses descending; otherwise the
    // direction is read from op param 0 of dst.
    const int32_t order = [&]() -> int32_t {
        if (context.dst->op == GGML_OP_TOP_K) {
            return static_cast<int32_t>(GGML_SORT_ORDER_DESC);
        }
        return static_cast<int32_t>(ggml_get_op_params_i32(context.dst, 0));
    }();

    if (auto hit = argsort_pipelines.find(order); hit != argsort_pipelines.end()) {
        return hit->second;
    }

    const std::string order_str = std::to_string(order);

    // Double the workgroup size (starting at 1) while the doubled size stays within
    // max_wg_size and the current size passes the i32 memory check against half of
    // wg_mem_limit_bytes. Note the memory check uses the size *before* doubling.
    uint32_t wg_size = 1;
    for (;;) {
        const bool room_for_threads = wg_size * 2 <= context.max_wg_size;
        const bool room_in_memory   = wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2;
        if (!(room_for_threads && room_in_memory)) {
            break;
        }
        wg_size *= 2;
    }

    std::vector<std::string> defines{
        "ORDER=" + order_str,
        "WG_SIZE=" + std::to_string(wg_size),
    };
    const std::string variant = "argsort_order" + order_str;

    auto processed = preprocessor.preprocess(wgsl_argsort, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    argsort_pipelines[order] = pipeline;
    return pipeline;
}
|
|
1469
|
+
|
|
1470
|
+
webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached ARGSORT merge-pass pipeline for a sort direction
    // (0 = ascending, 1 = descending). No decisions context is attached to this pipeline.
    int32_t order = static_cast<int32_t>(GGML_SORT_ORDER_DESC);  // TOP_K always sorts descending
    if (context.dst->op != GGML_OP_TOP_K) {
        order = static_cast<int32_t>(ggml_get_op_params_i32(context.dst, 0));
    }

    auto hit = argsort_merge_pipelines.find(order);
    if (hit != argsort_merge_pipelines.end()) {
        return hit->second;
    }

    const std::string order_str = std::to_string(order);
    // Workgroup size is capped by both the merge-specific maximum and the device maximum.
    const uint32_t    wg_size   = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);

    std::vector<std::string> defines;
    defines.push_back("ORDER=" + order_str);
    defines.push_back("WG_SIZE=" + std::to_string(wg_size));

    std::string variant   = "argsort_merge_order" + order_str;
    auto        processed = preprocessor.preprocess(wgsl_argsort_merge, defines);

    argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
    return argsort_merge_pipelines[order];
}
|
|
1492
|
+
|
|
1493
|
+
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached GET_ROWS pipeline, keyed on the src0 element type and on whether the
    // f32 path is vectorized (src0 is F32 and dst->ne[0] is a multiple of 4).
    ggml_webgpu_get_rows_pipeline_key key = {};
    key.src_type   = context.src0->type;
    key.vectorized = static_cast<int>(context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0);

    if (auto hit = get_rows_pipelines.find(key); hit != get_rows_pipelines.end()) {
        return hit->second;
    }

    const char * type_str = ggml_get_type_traits(key.src_type)->type_name;

    std::vector<std::string> defines;
    std::string              variant = "get_rows";

    // F32/F16/I32 sources: pushes FLOAT_PARALLEL, a type tag, SRC_TYPE, DST_TYPE and
    // BLOCK_SIZE, in that order.
    auto add_element_defines = [&defines](const char * type_tag, const char * src, const char * dst,
                                          const char * block) {
        defines.push_back("FLOAT_PARALLEL");
        defines.push_back(type_tag);
        defines.push_back(std::string("SRC_TYPE=") + src);
        defines.push_back(std::string("DST_TYPE=") + dst);
        defines.push_back(std::string("BLOCK_SIZE=") + block);
    };

    switch (key.src_type) {
        case GGML_TYPE_F32:
            if (key.vectorized) {
                add_element_defines("F32_VEC", "vec4<f32>", "vec4<f32>", "4u");
            } else {
                add_element_defines("F32", "f32", "f32", "1u");
            }
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            add_element_defines("F16", "f16", "f32", "1u");
            variant += "_f16";
            break;
        case GGML_TYPE_I32:
            add_element_defines("I32", "i32", "i32", "1u");
            variant += "_i32";
            break;
        default:
            {
                std::string type_upper = type_str;
                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

                // The listed types are read as u32 words with the U32 dequant helpers; any other
                // type in this branch uses its ggml type name as SRC_TYPE.
                const bool reads_u32_words = [&] {
                    switch (key.src_type) {
                        case GGML_TYPE_Q1_0:
                        case GGML_TYPE_Q4_0:
                        case GGML_TYPE_Q5_0:
                        case GGML_TYPE_Q8_0:
                        case GGML_TYPE_Q3_K:
                        case GGML_TYPE_Q6_K:
                        case GGML_TYPE_IQ2_XXS:
                        case GGML_TYPE_IQ2_XS:
                        case GGML_TYPE_IQ2_S:
                        case GGML_TYPE_IQ3_XXS:
                        case GGML_TYPE_IQ3_S:
                        case GGML_TYPE_IQ1_S:
                        case GGML_TYPE_IQ4_NL:
                        case GGML_TYPE_MXFP4:
                            return true;
                        default:
                            return false;
                    }
                }();

                if (reads_u32_words) {
                    defines.push_back("SRC_TYPE=u32");
                    defines.push_back("U32_DEQUANT_HELPERS");
                } else {
                    defines.push_back(std::string("SRC_TYPE=") + type_str);
                }

                // Every per-type switch (<TYPE>_T, <TYPE>, _SCALE_MIN, _TABLES, _GRID, _LUT) is
                // defined for every type in this branch.
                defines.push_back("BYTE_HELPERS");
                for (const char * suffix : { "_T", "", "_SCALE_MIN", "_TABLES", "_GRID", "_LUT" }) {
                    defines.push_back(type_upper + suffix);
                }

                variant += std::string("_") + type_str;

                defines.push_back("DST_TYPE=f32");

                // Elements per block: 128 for Q1_0; 32 for the Q4_0..Q8_1 range plus IQ4_NL and
                // MXFP4; 256 for the remaining types at or after Q2_K; otherwise 1.
                const char * block_size = "BLOCK_SIZE=1u";
                if (key.src_type == GGML_TYPE_Q1_0) {
                    block_size = "BLOCK_SIZE=128u";
                } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
                           key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
                    block_size = "BLOCK_SIZE=32u";
                } else if (key.src_type >= GGML_TYPE_Q2_K) {
                    block_size = "BLOCK_SIZE=256u";
                }
                defines.push_back(block_size);
                break;
            }
    }

    if (key.vectorized) {
        variant += "_vec";
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_get_rows, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    get_rows_pipelines[key]  = pipeline;
    return pipeline;
}
|
|
1615
|
+
|
|
1616
|
+
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached SCALE pipeline. The only key field is whether dst is the same tensor
    // as src0 (in-place variant).
    ggml_webgpu_scale_pipeline_key key = {};
    key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

    auto hit = scale_pipelines.find(key);
    if (hit != scale_pipelines.end()) {
        return hit->second;
    }

    std::vector<std::string> defines;
    if (key.inplace) {
        defines.emplace_back("INPLACE");
    }
    defines.emplace_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    const std::string variant = key.inplace ? "scale_inplace" : "scale";

    auto processed = preprocessor.preprocess(wgsl_scale, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    scale_pipelines[key]     = pipeline;
    return pipeline;
}
|
|
1644
|
+
|
|
1645
|
+
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached SOLVE_TRI pipeline, specialised on n = src0->ne[0].
    ggml_webgpu_solve_tri_pipeline_key key = {};
    key.type = context.dst->type;
    key.n    = static_cast<int>(context.src0->ne[0]);
    key.k    = static_cast<int>(context.src1->ne[0]);  // cache-key only: no define below reads k

    if (auto hit = solve_tri_pipelines.find(key); hit != solve_tri_pipelines.end()) {
        return hit->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for solve_tri shader");
    }
    const std::string variant = "solve_tri_f32";

    const uint32_t n             = static_cast<uint32_t>(key.n);
    const uint32_t wg_size       = std::min(n, context.max_wg_size);
    const uint32_t k_tile        = wg_size;
    // Each batched row is budgeted as (n + wg_size) f32 values of workgroup memory.
    const uint32_t bytes_per_row = (n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
    // NOTE(review): batch_n is 0 when one row exceeds wg_mem_limit_bytes; not guarded here.
    const uint32_t batch_n       = static_cast<uint32_t>(context.wg_mem_limit_bytes / bytes_per_row);

    std::vector<std::string> defines;
    defines.push_back("N=" + std::to_string(key.n));
    defines.push_back("WG_SIZE=" + std::to_string(wg_size));
    defines.push_back("K_TILE=" + std::to_string(k_tile));
    defines.push_back("BATCH_N=" + std::to_string(batch_n));

    auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    solve_tri_pipelines[key] = pipeline;
    return pipeline;
}
|
|
1685
|
+
|
|
1686
|
+
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached SSM_CONV pipeline. The VECTORIZED variant is chosen when
    // src1->ne[0] == 4. Block size and tokens per workgroup are fixed constants.
    constexpr uint32_t block_size    = 32u;
    constexpr uint32_t tokens_per_wg = 8u;

    ggml_webgpu_ssm_conv_pipeline_key key = {};
    key.type       = context.dst->type;
    key.vectorized = context.src1->ne[0] == 4;

    if (auto hit = ssm_conv_pipelines.find(key); hit != ssm_conv_pipelines.end()) {
        return hit->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for ssm_conv shader");
    }

    std::vector<std::string> defines;
    std::string              variant = "ssm_conv_f32";
    if (key.vectorized) {
        defines.emplace_back("VECTORIZED");
        variant.append("_vec4");
    }
    defines.emplace_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
    defines.emplace_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");

    auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);

    auto decisions           = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
    decisions->block_size    = block_size;
    decisions->tokens_per_wg = tokens_per_wg;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    ssm_conv_pipelines[key]  = pipeline;
    return pipeline;
}
|
|
1727
|
+
|
|
1728
|
+
webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached SSM_SCAN pipeline, specialised on d_state = src0->ne[0] (also used as
    // the workgroup size) and on xbc_overlap (src1 overlaps both src4 and src5).
    // The reduction strategy (subgroup vs workgroup) follows context.supports_subgroups.
    ggml_webgpu_ssm_scan_pipeline_key key = {};
    key.type        = context.dst->type;
    key.d_state     = (int) context.src0->ne[0];
    key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
                      ggml_webgpu_tensor_overlap(context.src1, context.src5);

    auto it = ssm_scan_pipelines.find(key);
    if (it != ssm_scan_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "ssm_scan";

    switch (key.type) {
        case GGML_TYPE_F32:
            variant += "_f32";
            break;
        default:
            GGML_ABORT("Unsupported type for ssm_scan shader");
    }

    // The workgroup size equals d_state.
    const uint32_t wg_size = (uint32_t) key.d_state;

    constexpr uint32_t tokens_per_tile = 4u;

    defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u");
    defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u");

    if (context.supports_subgroups) {
        defines.push_back("USE_SUBGROUP_REDUCTION");
        variant += "_sg_reduce";
    } else {
        variant += "_wg_reduce";
    }

    if (key.xbc_overlap) {
        defines.push_back("XBC_OVERLAP");
        // Keep the variant name unique per cache key, as every sibling pipeline does for its
        // feature switches; previously the overlapping and non-overlapping shaders shared a name.
        variant += "_xbc_overlap";
    }

    variant += "_d" + std::to_string(key.d_state);

    auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines);

    auto decisions             = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>();
    decisions->wg_size         = wg_size;
    decisions->tokens_per_tile = tokens_per_tile;
    decisions->xbc_overlap     = key.xbc_overlap;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    ssm_scan_pipelines[key]  = pipeline;
    return ssm_scan_pipelines[key];
}
|
|
1781
|
+
|
|
1782
|
+
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached GATED_DELTA_NET pipeline, specialised on S_V = src2->ne[0]. S_V is also
    // used as the workgroup size. The KDA variant is chosen when src3->ne[0] == src2->ne[0].
    // No decisions context is attached to this pipeline.
    ggml_webgpu_gated_delta_net_pipeline_key key = {};
    key.type = context.dst->type;
    key.s_v  = static_cast<int>(context.src2->ne[0]);
    key.kda  = context.src3->ne[0] == context.src2->ne[0];

    if (auto hit = gated_delta_net_pipelines.find(key); hit != gated_delta_net_pipelines.end()) {
        return hit->second;
    }

    if (key.type != GGML_TYPE_F32) {
        GGML_ABORT("Unsupported type for gated_delta_net shader");
    }

    std::vector<std::string> defines;
    std::string              variant = "gated_delta_net_f32";
    if (key.kda) {
        defines.emplace_back("KDA");
        variant.append("_kda");
    }

    const std::string s_v_literal = std::to_string(key.s_v) + "u";
    defines.emplace_back("S_V=" + s_v_literal);
    defines.emplace_back("WG_SIZE=" + s_v_literal);

    auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);

    gated_delta_net_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
    return gated_delta_net_pipelines[key];
}
|
|
1817
|
+
|
|
1818
|
+
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached PAD pipeline. A non-zero op param 8 on dst selects the CIRCULAR variant.
    ggml_webgpu_pad_pipeline_key key = {};
    key.circular = ggml_get_op_params_i32(context.dst, 8) != 0;

    auto hit = pad_pipelines.find(key);
    if (hit != pad_pipelines.end()) {
        return hit->second;
    }

    std::vector<std::string> defines;
    if (key.circular) {
        defines.emplace_back("CIRCULAR");
    }
    defines.emplace_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    const std::string variant = key.circular ? "pad_circular" : "pad";

    auto processed = preprocessor.preprocess(wgsl_pad, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    pad_pipelines[key]       = pipeline;
    return pipeline;
}
|
|
1845
|
+
|
|
1846
|
+
webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached QUANTIZE_Q8 pipeline, keyed on src0's type. Emits Q8_1_T, an f32 src1
    // element type, and a src0-specific MUL_ACC_<TYPE> define. Its workgroup size is
    // WEBGPU_MUL_MAT_VEC_WG_SIZE.
    ggml_webgpu_quantize_q8_pipeline_key key = {};
    key.src0_type = context.src0->type;

    if (auto hit = quantize_q8_pipelines.find(key); hit != quantize_q8_pipelines.end()) {
        return hit->second;
    }

    const uint32_t wg_size      = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    const bool     use_subgroup = context.supports_subgroups;

    const std::string type_name = ggml_get_type_traits(context.src0->type)->type_name;
    std::string       type_upper(type_name);
    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

    std::vector<std::string> defines;
    defines.push_back("SRC1_INNER_TYPE=f32");
    defines.push_back("WG_SIZE=" + std::to_string(wg_size));
    defines.push_back("MUL_ACC_" + type_upper);
    defines.push_back("Q8_1_T");
    defines.push_back(use_subgroup ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");

    const std::string variant = "quantize_q8_" + type_name + (use_subgroup ? "_sg_reduce" : "_wg_reduce");

    auto processed = preprocessor.preprocess(wgsl_quantize_q8, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = wg_size;

    webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context           = decisions;
    quantize_q8_pipelines[key] = pipeline;
    return pipeline;
}
|
|
1883
|
+
|
|
1884
|
+
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Returns the cached matrix-vector multiply pipeline. The cache key covers the src0 (matrix)
    // and src1 (vector) types, whether vec4 loads are used, and whether the MMVQ path is taken
    // (as decided by ggml_webgpu_can_use_mmvq).
    ggml_webgpu_mul_mat_vec_pipeline_key key = {};
    key.src0_type  = context.src0->type;
    key.src1_type  = context.src1->type;
    // vec4 loads only for F32/F16 matrices whose row length is a multiple of 4.
    key.vectorized = (context.src0->ne[0] % 4 == 0 &&
                      (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
                         1 :
                         0;
    key.use_mmvq =
        ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);

    auto it = mul_mat_vec_pipelines.find(key);
    if (it != mul_mat_vec_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant    = "mul_mat_vec";
    const char *             shader_src = wgsl_mul_mat_vec;

    // src0 type (matrix row)
    switch (context.src0->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC0_INNER_TYPE=f32");
            defines.push_back("MUL_ACC_FLOAT");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC0_INNER_TYPE=f16");
            defines.push_back("MUL_ACC_FLOAT");
            variant += "_f16";
            break;
        default:
            {
                // Quantized types: src0 is read as u32 words through the dequantization helpers.
                const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
                std::string                     src0_name   = src0_traits->type_name;
                std::string                     type_upper  = src0_name;
                variant += "_" + src0_name;
                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

                defines.push_back("BYTE_HELPERS");
                defines.push_back("MUL_ACC_" + type_upper);
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");
                // Per-type extras: MMVQ family tags, and lookup grids/tables/LUTs for IQ/MXFP types.
                switch (context.src0->type) {
                    case GGML_TYPE_Q8_0:
                    case GGML_TYPE_Q4_0:
                    case GGML_TYPE_Q4_1:
                        if (key.use_mmvq) {
                            defines.push_back("LEGACY_QUANTS");
                        }
                        break;
                    case GGML_TYPE_Q2_K:
                    case GGML_TYPE_Q4_K:
                        if (key.use_mmvq) {
                            defines.push_back("K_QUANTS");
                        }
                        break;
                    case GGML_TYPE_IQ1_S:
                    case GGML_TYPE_IQ1_M:
                    case GGML_TYPE_IQ2_S:
                    case GGML_TYPE_IQ3_S:
                    case GGML_TYPE_IQ4_NL:
                    case GGML_TYPE_IQ4_XS:
                        defines.push_back(type_upper + "_GRID");
                        break;
                    case GGML_TYPE_IQ2_XXS:
                    case GGML_TYPE_IQ2_XS:
                    case GGML_TYPE_IQ3_XXS:
                        defines.push_back(type_upper + "_GRID");
                        defines.push_back(type_upper + "_TABLES");
                        break;
                    case GGML_TYPE_MXFP4:
                        defines.push_back(type_upper + "_LUT");
                        break;
                    default:
                        break;
                }
                break;
            }
    }

    // src1 type (vector)
    switch (context.src1->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC1_INNER_TYPE=f32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC1_INNER_TYPE=f16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
    }

    // VEC/SCALAR controls
    defines.push_back(key.vectorized ? "VEC" : "SCALAR");

    uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

    // Outputs per workgroup by type family. Q1_0 is checked first because the next test
    // (>= Q2_K) would otherwise classify it as a K-quant.
    if (key.src0_type == GGML_TYPE_Q1_0) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
    } else if (key.src0_type >= GGML_TYPE_Q2_K) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
    } else if (key.src0_type >= GGML_TYPE_Q4_0) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
    }

    if (key.use_mmvq) {
        defines.push_back("MMVQ");
        defines.push_back("Q8_1_T");
        // use_mmvq is part of the cache key and changes the shader, so it must also be reflected
        // in the variant name; otherwise MMVQ and non-MMVQ pipelines of the same types share a name.
        variant += "_mmvq";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
    defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
    variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
    if (key.vectorized) {
        variant += "_vectorized";
    }

    auto processed = preprocessor.preprocess(shader_src, defines);

    auto decisions            = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
    decisions->wg_size        = wg_size;
    decisions->outputs_per_wg = outputs_per_wg;
    decisions->vec_size       = key.vectorized ? 4 : 1;

    webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context           = decisions;
    mul_mat_vec_pipelines[key] = pipeline;
    return mul_mat_vec_pipelines[key];
}
|
|
2019
|
+
|
|
2020
|
+
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
2021
|
+
ggml_webgpu_mul_mat_pipeline_key key = {};
|
|
2022
|
+
key.src0_type = context.src0->type;
|
|
2023
|
+
key.src1_type = context.src1->type;
|
|
2024
|
+
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
|
2025
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
2026
|
+
1 :
|
|
2027
|
+
0;
|
|
2028
|
+
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
|
2029
|
+
|
|
2030
|
+
auto it = mul_mat_fast_pipelines.find(key);
|
|
2031
|
+
if (it != mul_mat_fast_pipelines.end()) {
|
|
2032
|
+
return it->second;
|
|
2033
|
+
}
|
|
2034
|
+
|
|
2035
|
+
const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
|
|
2036
|
+
std::vector<std::string> defines;
|
|
2037
|
+
std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
|
|
2038
|
+
|
|
2039
|
+
// src1 type
|
|
2040
|
+
switch (context.src1->type) {
|
|
2041
|
+
case GGML_TYPE_F32:
|
|
2042
|
+
defines.push_back("SRC1_INNER_TYPE=f32");
|
|
2043
|
+
break;
|
|
2044
|
+
case GGML_TYPE_F16:
|
|
2045
|
+
defines.push_back("SRC1_INNER_TYPE=f16");
|
|
2046
|
+
break;
|
|
2047
|
+
default:
|
|
2048
|
+
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
|
2049
|
+
}
|
|
2050
|
+
|
|
2051
|
+
// src0 type
|
|
2052
|
+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
2053
|
+
const char * src0_name = src0_traits->type_name;
|
|
2054
|
+
|
|
2055
|
+
switch (context.src0->type) {
|
|
2056
|
+
case GGML_TYPE_F32:
|
|
2057
|
+
defines.push_back("SRC0_INNER_TYPE=f32");
|
|
2058
|
+
defines.push_back("FLOAT");
|
|
2059
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
2060
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
2061
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
2062
|
+
variant += "_f32";
|
|
2063
|
+
break;
|
|
2064
|
+
case GGML_TYPE_F16:
|
|
2065
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
2066
|
+
defines.push_back("FLOAT");
|
|
2067
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
2068
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
2069
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
2070
|
+
variant += "_f16";
|
|
2071
|
+
break;
|
|
2072
|
+
default:
|
|
2073
|
+
{
|
|
2074
|
+
std::string type_upper = src0_name;
|
|
2075
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
2076
|
+
|
|
2077
|
+
defines.push_back("BYTE_HELPERS");
|
|
2078
|
+
defines.push_back("MUL_ACC_" + type_upper);
|
|
2079
|
+
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
|
2080
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
2081
|
+
defines.push_back("U32_DEQUANT_HELPERS");
|
|
2082
|
+
defines.push_back("SRC0_INNER_TYPE=u32");
|
|
2083
|
+
|
|
2084
|
+
switch (context.src0->type) {
|
|
2085
|
+
case GGML_TYPE_IQ1_S:
|
|
2086
|
+
case GGML_TYPE_IQ1_M:
|
|
2087
|
+
case GGML_TYPE_IQ4_NL:
|
|
2088
|
+
case GGML_TYPE_IQ4_XS:
|
|
2089
|
+
defines.push_back(type_upper + "_GRID");
|
|
2090
|
+
break;
|
|
2091
|
+
case GGML_TYPE_IQ2_XXS:
|
|
2092
|
+
case GGML_TYPE_IQ2_XS:
|
|
2093
|
+
case GGML_TYPE_IQ2_S:
|
|
2094
|
+
case GGML_TYPE_IQ3_XXS:
|
|
2095
|
+
case GGML_TYPE_IQ3_S:
|
|
2096
|
+
defines.push_back(type_upper + "_GRID");
|
|
2097
|
+
defines.push_back(type_upper + "_TABLES");
|
|
2098
|
+
break;
|
|
2099
|
+
case GGML_TYPE_MXFP4:
|
|
2100
|
+
defines.push_back(type_upper + "_LUT");
|
|
2101
|
+
break;
|
|
2102
|
+
default:
|
|
2103
|
+
break;
|
|
2104
|
+
}
|
|
2105
|
+
|
|
2106
|
+
variant += std::string("_") + src0_name;
|
|
2107
|
+
break;
|
|
2108
|
+
}
|
|
2109
|
+
}
|
|
2110
|
+
|
|
2111
|
+
// VEC/SCALAR controls
|
|
2112
|
+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
2113
|
+
|
|
2114
|
+
const bool is_quant = ggml_is_quantized(context.src0->type);
|
|
2115
|
+
|
|
2116
|
+
uint32_t tile_k;
|
|
2117
|
+
if (key.use_subgroup_matrix) {
|
|
2118
|
+
tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT;
|
|
2119
|
+
} else {
|
|
2120
|
+
tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;
|
|
2121
|
+
}
|
|
2122
|
+
|
|
2123
|
+
// Tiles
|
|
2124
|
+
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
|
2125
|
+
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
|
2126
|
+
|
|
2127
|
+
// Subgroup matrix specifics
|
|
2128
|
+
if (key.use_subgroup_matrix) {
|
|
2129
|
+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
|
2130
|
+
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
|
2131
|
+
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
|
2132
|
+
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
|
2133
|
+
defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
|
|
2134
|
+
defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
|
|
2135
|
+
defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
|
|
2136
|
+
defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
|
|
2137
|
+
defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
|
|
2138
|
+
}
|
|
2139
|
+
|
|
2140
|
+
// variant suffix for src1 type
|
|
2141
|
+
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
|
2142
|
+
if (key.vectorized) {
|
|
2143
|
+
variant += "_vectorized";
|
|
2144
|
+
}
|
|
2145
|
+
|
|
2146
|
+
if (!key.use_subgroup_matrix) {
|
|
2147
|
+
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
|
2148
|
+
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
|
2149
|
+
defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");
|
|
2150
|
+
}
|
|
2151
|
+
|
|
2152
|
+
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
2153
|
+
|
|
2154
|
+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
|
2155
|
+
decisions->tile_k = tile_k;
|
|
2156
|
+
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
|
2157
|
+
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
|
2158
|
+
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
|
|
2159
|
+
if (key.use_subgroup_matrix) {
|
|
2160
|
+
decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
|
|
2161
|
+
decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
|
|
2162
|
+
decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
|
|
2163
|
+
decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
|
|
2164
|
+
decisions->wg_size = context.max_subgroup_size;
|
|
2165
|
+
} else {
|
|
2166
|
+
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
|
2167
|
+
decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
2168
|
+
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
2169
|
+
decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
|
2170
|
+
}
|
|
2171
|
+
|
|
2172
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
2173
|
+
pipeline.context = decisions;
|
|
2174
|
+
mul_mat_fast_pipelines[key] = pipeline;
|
|
2175
|
+
return mul_mat_fast_pipelines[key];
|
|
2176
|
+
}
|
|
2177
|
+
|
|
2178
|
+
// Returns the pipeline for the MUL_MAT_ID gather shader (wgsl_mul_mat_id_gather).
// Only one variant exists, so it is cached under the fixed key 1. The sole shader
// define is WG_SIZE = context.max_wg_size, which is also recorded in the attached
// generic decisions.
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
    constexpr int k_gather_key = 1;  // the single gather variant

    const auto cached = mul_mat_id_gather_pipelines.find(k_gather_key);
    if (cached != mul_mat_id_gather_pipelines.end()) {
        return cached->second;
    }

    const auto               wg_size = context.max_wg_size;
    std::vector<std::string> shader_defines{ "WG_SIZE=" + std::to_string(wg_size) };

    auto gather_decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    gather_decisions->wg_size = wg_size;

    webgpu_pipeline gather = ggml_webgpu_create_pipeline(
        device, preprocessor.preprocess(wgsl_mul_mat_id_gather, shader_defines), "mul_mat_id_gather");
    gather.context = gather_decisions;

    mul_mat_id_gather_pipelines[k_gather_key] = gather;
    return gather;
}
|
|
2195
|
+
|
|
2196
|
+
// Returns the register-tiled MUL_MAT_ID pipeline for context.src0 / context.src1, compiling it on
// first use and caching it by (src0 type, src1 type, n_experts = src0->ne[2], vectorized).
//
// - src1 must be F32 or F16; any other type aborts.
// - F32/F16 src0 is loaded directly. Every other src0 type goes through the byte/u32 dequant
//   helpers, with extra per-type _GRID / _TABLES / _LUT defines for the IQ* and MXFP4 formats.
// - VEC is selected only for F32/F16 src0 whose ne[0] and ne[1] are both multiples of 4.
//
// NOTE(review): n_experts is part of the key, but this function adds no define that depends on
// it, so different expert counts compile identical shaders. Confirm whether that is intended.
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_mul_mat_id_pipeline_key key = {};
    key.src0_type  = context.src0->type;
    key.src1_type  = context.src1->type;
    key.n_experts  = context.src0->ne[2];
    key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
                      (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
                         1 :
                         0;

    auto it = mul_mat_id_pipelines.find(key);
    if (it != mul_mat_id_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "mul_mat_id";
    defines.push_back("MUL_MAT_ID");

    // src1 type
    switch (context.src1->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC1_INNER_TYPE=f32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC1_INNER_TYPE=f16");
            break;
        default:
            GGML_ABORT("Unsupported src1 type for mul_mat_id shader");
    }

    // src0 type
    const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
    const char *                    src0_name   = src0_traits->type_name;

    switch (context.src0->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC0_INNER_TYPE=f32");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC0_INNER_TYPE=f16");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
            variant += "_f16";
            break;
        default:
            {
                // Shader defines use the upper-cased ggml type name, e.g. INIT_SRC0_SHMEM_<TYPE>.
                std::string type_upper = src0_name;
                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

                defines.push_back("BYTE_HELPERS");
                defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");

                // Extra lookup data required by some formats.
                switch (context.src0->type) {
                    case GGML_TYPE_IQ1_S:
                    case GGML_TYPE_IQ1_M:
                    case GGML_TYPE_IQ4_NL:
                    case GGML_TYPE_IQ4_XS:
                        defines.push_back(type_upper + "_GRID");
                        break;
                    case GGML_TYPE_IQ2_XXS:
                    case GGML_TYPE_IQ2_XS:
                    case GGML_TYPE_IQ2_S:
                    case GGML_TYPE_IQ3_XXS:
                    case GGML_TYPE_IQ3_S:
                        defines.push_back(type_upper + "_GRID");
                        defines.push_back(type_upper + "_TABLES");
                        break;
                    case GGML_TYPE_MXFP4:
                        defines.push_back(type_upper + "_LUT");
                        break;
                    default:
                        break;
                }

                variant += std::string("_") + src0_name;
                break;
            }
    }

    // VEC/SCALAR controls
    defines.push_back(key.vectorized ? "VEC" : "SCALAR");

    // mul_mat_id is register-tile only.
    const uint32_t tile_k =
        ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT;

    // Tiles
    defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
    defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
    defines.push_back("TILE_K=" + std::to_string(tile_k) + "u");

    defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
    defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");

    // variant suffix for src1 type
    variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
    if (key.vectorized) {
        variant += "_vectorized";
    }

    auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);

    // Record the tiling chosen above so the dispatch code can size its grid consistently.
    auto decisions       = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
    decisions->tile_k    = tile_k;
    decisions->tile_m    = WEBGPU_MUL_MAT_TILE_M;
    decisions->tile_n    = WEBGPU_MUL_MAT_TILE_N;
    decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
    decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
    decisions->wg_size   = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    mul_mat_id_pipelines[key] = pipeline;
    return mul_mat_id_pipelines[key];
}
|
|
2318
|
+
|
|
2319
|
+
// Returns the matrix-vector MUL_MAT_ID pipeline (wgsl_mul_mat_id_vec), compiling it on first use
// and caching it by (src0 type, src1 type, n_experts = src0->ne[2], vectorized).
//
// - src1 must be F32 or F16; any other type aborts.
// - VEC is selected only for F32/F16 src0 with ne[0] a multiple of 4.
// - The expert count is baked into the shader as N_EXPERTS.
// - Reduction uses subgroup operations when context.supports_subgroups, otherwise workgroup memory.
webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_mul_mat_id_pipeline_key key = {};
    key.src0_type  = context.src0->type;
    key.src1_type  = context.src1->type;
    key.n_experts  = context.src0->ne[2];
    key.vectorized = (context.src0->ne[0] % 4 == 0 &&
                      (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
                         1 :
                         0;

    auto it = mul_mat_id_vec_pipelines.find(key);
    if (it != mul_mat_id_vec_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant    = "mul_mat_id_vec";
    const char *             shader_src = wgsl_mul_mat_id_vec;

    // src1 type
    switch (context.src1->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC1_INNER_TYPE=f32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC1_INNER_TYPE=f16");
            break;
        default:
            GGML_ABORT("Unsupported src1 type for mul_mat_id_vec shader");
    }

    // src0 type
    switch (context.src0->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC0_INNER_TYPE=f32");
            defines.push_back("MUL_ACC_FLOAT");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC0_INNER_TYPE=f16");
            defines.push_back("MUL_ACC_FLOAT");
            variant += "_f16";
            break;
        default:
            {
                // Quantized types: use helpers but accumulate in f16
                const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
                std::string                     src0_name   = src0_traits->type_name;
                std::string                     type_upper  = src0_name;
                variant += "_" + src0_name;
                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

                defines.push_back("BYTE_HELPERS");
                defines.push_back("MUL_ACC_" + type_upper);
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");

                // Extra lookup data required by some formats.
                switch (context.src0->type) {
                    case GGML_TYPE_IQ1_S:
                    case GGML_TYPE_IQ1_M:
                    case GGML_TYPE_IQ2_S:
                    case GGML_TYPE_IQ3_S:
                    case GGML_TYPE_IQ4_NL:
                    case GGML_TYPE_IQ4_XS:
                        defines.push_back(type_upper + "_GRID");
                        break;
                    case GGML_TYPE_IQ2_XXS:
                    case GGML_TYPE_IQ2_XS:
                    case GGML_TYPE_IQ3_XXS:
                        defines.push_back(type_upper + "_GRID");
                        defines.push_back(type_upper + "_TABLES");
                        break;
                    case GGML_TYPE_MXFP4:
                        defines.push_back(type_upper + "_LUT");
                        break;
                    default:
                        break;
                }
                break;
            }
    }

    // VEC/SCALAR controls
    defines.push_back(key.vectorized ? "VEC" : "SCALAR");

    uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

    // Outputs per workgroup depend on the src0 family. This relies on ggml_type enum ordering:
    // Q1_0 is special-cased first, types at or after Q2_K use the K-quant setting, and the
    // remaining types at or after Q4_0 use the legacy-quant setting.
    if (key.src0_type == GGML_TYPE_Q1_0) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
    } else if (key.src0_type >= GGML_TYPE_Q2_K) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
    } else if (key.src0_type >= GGML_TYPE_Q4_0) {
        outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
    }

    // variant suffix for src1 type
    variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
    defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
    variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
    if (key.vectorized) {
        variant += "_vectorized";
    }

    defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));

    auto processed = preprocessor.preprocess(shader_src, defines);

    auto decisions            = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
    decisions->wg_size        = wg_size;
    decisions->outputs_per_wg = outputs_per_wg;

    webgpu_pipeline pipeline      = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context              = decisions;
    mul_mat_id_vec_pipelines[key] = pipeline;
    return mul_mat_id_vec_pipelines[key];
}
|
|
2438
|
+
|
|
2439
|
+
// Returns the pipeline built from wgsl_unary for context.dst. It serves both GGML_OP_UNARY nodes
// and plain ggml ops that share this shader, such as GGML_OP_FILL and GGML_OP_TRI. The pipeline is
// compiled on first use and cached by (dst type, op, is_unary, inplace, tri type).
//
// - Only F32/F16 dst types are supported; anything else aborts.
// - GGML_OP_FILL is always treated as in-place.
// - For GGML_OP_TRI, op_params[0] selects the triangle variant.
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const bool is_unary = context.dst->op == GGML_OP_UNARY;
    const int  op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
    // `op` holds a ggml_unary_op when is_unary and a ggml_op otherwise, so a TRI test must rule out
    // the unary case instead of comparing the raw integer alone.
    const bool is_tri   = !is_unary && op == GGML_OP_TRI;

    ggml_webgpu_unary_pipeline_key key = {};
    key.type     = context.dst->type;
    key.op       = op;
    key.is_unary = is_unary;
    key.inplace  = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL;
    // op_params[0] is only a ggml_tri_type for GGML_OP_TRI. For other ops that slot presumably holds
    // unrelated per-node data (e.g. the unary op id, or op-specific parameters). Copying it into the
    // key would then compile a separate but identical pipeline per distinct value, so ttype stays
    // zero-initialized for non-TRI ops.
    if (is_tri) {
        key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0);
    }

    auto it = unary_pipelines.find(key);
    if (it != unary_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    // The op's ggml name doubles as a shader define and as the start of the pipeline label.
    std::string variant =
        key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
    defines.push_back(variant);

    switch (key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for unary shader");
    }

    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    if (is_tri) {
        switch (key.ttype) {
            case GGML_TRI_TYPE_LOWER:
                defines.push_back("TRI_TYPE_LOWER");
                variant += "_tri_type_lower";
                break;
            case GGML_TRI_TYPE_LOWER_DIAG:
                defines.push_back("TRI_TYPE_LOWER_DIAG");
                variant += "_tri_type_lower_diag";
                break;
            case GGML_TRI_TYPE_UPPER:
                defines.push_back("TRI_TYPE_UPPER");
                variant += "_tri_type_upper";
                break;
            case GGML_TRI_TYPE_UPPER_DIAG:
                defines.push_back("TRI_TYPE_UPPER_DIAG");
                variant += "_tri_type_upper_diag";
                break;
            default:
                GGML_ABORT("Unsupported ggml_tri_type for unary shader");
        }
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed      = preprocessor.preprocess(wgsl_unary, defines);
    auto decisions      = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size  = context.max_wg_size;
    decisions->inplace  = key.inplace;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    unary_pipelines[key]     = pipeline;
    return unary_pipelines[key];
}
|
|
2511
|
+
|
|
2512
|
+
// Returns the fused RMS_NORM * MUL pipeline (wgsl_rms_norm_mul), compiled once per aliasing
// configuration and cached. At most one aliasing define is emitted, checked in this order:
//   INPLACE     - src0 and dst are equal (ggml_webgpu_tensor_equal)
//   OVERLAP     - src1 and dst are equal
//   SRC_OVERLAP - src0 and src1 overlap (ggml_webgpu_tensor_overlap)
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_rms_norm_mul_pipeline_key key = {};
    key.inplace     = ggml_webgpu_tensor_equal(context.src0, context.dst);
    key.overlap     = ggml_webgpu_tensor_equal(context.src1, context.dst);
    key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);

    const auto hit = rms_norm_mul_pipelines.find(key);
    if (hit != rms_norm_mul_pipelines.end()) {
        return hit->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "RMS_NORM_MUL";

    // Adds one aliasing define plus its matching label suffix.
    const auto add_alias_mode = [&](const char * define, const char * suffix) {
        defines.push_back(define);
        variant += suffix;
    };
    if (key.inplace) {
        add_alias_mode("INPLACE", "_inplace");
    } else if (key.overlap) {
        add_alias_mode("OVERLAP", "_overlap");
    } else if (key.src_overlap) {
        add_alias_mode("SRC_OVERLAP", "_src_overlap");
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto rms_decisions         = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>();
    rms_decisions->wg_size     = context.max_wg_size;
    rms_decisions->inplace     = key.inplace;
    rms_decisions->overlap     = key.overlap;
    rms_decisions->src_overlap = key.src_overlap;

    webgpu_pipeline built =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_rms_norm_mul, defines), variant);
    built.context               = rms_decisions;
    rms_norm_mul_pipelines[key] = built;
    return built;
}
|
|
2551
|
+
|
|
2552
|
+
// Returns the element-wise binary pipeline (wgsl_binary) for context.dst's op and type, compiled
// once and cached by (type, op, aliasing mode).
//
// - The op's ggml name becomes an OP_<name> define.
// - Only F32/F16 are supported; other types abort.
// - At most one of INPLACE (src0 equals dst), OVERLAP (src1 equals dst) or SRC_OVERLAP (src0 and
//   src1 overlap) is defined, in that priority order.
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_binary_pipeline_key key = {};
    key.type        = context.dst->type;
    key.op          = context.dst->op;
    key.inplace     = ggml_webgpu_tensor_equal(context.src0, context.dst);
    key.overlap     = ggml_webgpu_tensor_equal(context.src1, context.dst);
    key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);

    const auto found = binary_pipelines.find(key);
    if (found != binary_pipelines.end()) {
        return found->second;
    }

    const std::string        op_name = ggml_op_name((ggml_op) key.op);
    std::string              variant = op_name;
    std::vector<std::string> defines{ "OP_" + op_name };

    if (key.type == GGML_TYPE_F32) {
        defines.push_back("TYPE_F32");
        variant += "_f32";
    } else if (key.type == GGML_TYPE_F16) {
        defines.push_back("TYPE_F16");
        variant += "_f16";
    } else {
        GGML_ABORT("Unsupported type for binary shader");
    }

    const char * alias_define = nullptr;
    const char * alias_suffix = nullptr;
    if (key.inplace) {
        alias_define = "INPLACE";
        alias_suffix = "_inplace";
    } else if (key.overlap) {
        alias_define = "OVERLAP";
        alias_suffix = "_overlap";
    } else if (key.src_overlap) {
        alias_define = "SRC_OVERLAP";
        alias_suffix = "_src_overlap";
    }
    if (alias_define != nullptr) {
        defines.push_back(alias_define);
        variant += alias_suffix;
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto binary_decisions         = std::make_shared<ggml_webgpu_binary_shader_decisions>();
    binary_decisions->wg_size     = context.max_wg_size;
    binary_decisions->inplace     = key.inplace;
    binary_decisions->overlap     = key.overlap;
    binary_decisions->src_overlap = key.src_overlap;

    webgpu_pipeline result =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_binary, defines), variant);
    result.context         = binary_decisions;
    binary_pipelines[key]  = result;
    return result;
}
|
|
2609
|
+
|
|
2610
|
+
// Returns the ADD_ID pipeline (wgsl_add_id), compiled once per in-place mode and cached.
// INPLACE is defined when src0 and dst are equal per ggml_webgpu_tensor_equal.
webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_add_id_pipeline_key key = {};
    key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);

    const auto cached = add_id_pipelines.find(key);
    if (cached != add_id_pipelines.end()) {
        return cached->second;
    }

    std::string              label = key.inplace ? "add_id_inplace" : "add_id";
    std::vector<std::string> defines;
    if (key.inplace) {
        defines.push_back("INPLACE");
    }
    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto add_id_decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    add_id_decisions->wg_size = context.max_wg_size;
    add_id_decisions->inplace = key.inplace;

    webgpu_pipeline add_id =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_add_id, defines), label);
    add_id.context        = add_id_decisions;
    add_id_pipelines[key] = add_id;
    return add_id;
}
|
|
2640
|
+
|
|
2641
|
+
// Returns the CONCAT pipeline (wgsl_concat) for context.dst's type, compiled once per
// (type, source overlap) and cached. Only F32 and I32 are supported; other types abort.
// SRC_OVERLAP is defined when src0 and src1 overlap per ggml_webgpu_tensor_overlap.
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_concat_pipeline_key key = {};
    key.type        = context.dst->type;
    key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);

    const auto existing = concat_pipelines.find(key);
    if (existing != concat_pipelines.end()) {
        return existing->second;
    }

    std::string              variant = "concat";
    std::vector<std::string> defines;

    if (key.type == GGML_TYPE_F32) {
        defines.push_back("TYPE_F32");
        variant += "_f32";
    } else if (key.type == GGML_TYPE_I32) {
        defines.push_back("TYPE_I32");
        variant += "_i32";
    } else {
        GGML_ABORT("Unsupported type for concat shader");
    }

    if (key.src_overlap) {
        defines.push_back("SRC_OVERLAP");
        variant += "_src_overlap";
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto concat_decisions         = std::make_shared<ggml_webgpu_binary_shader_decisions>();
    concat_decisions->wg_size     = context.max_wg_size;
    concat_decisions->src_overlap = key.src_overlap;

    webgpu_pipeline concat =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_concat, defines), variant);
    concat.context        = concat_decisions;
    concat_pipelines[key] = concat;
    return concat;
}
|
|
2683
|
+
|
|
2684
|
+
// Returns the REPEAT pipeline (wgsl_repeat) for context.dst's type, compiled once per type and
// cached. Supported types are F32, I32 and I16; anything else aborts.
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_repeat_pipeline_key key = {};
    key.type = context.dst->type;

    const auto existing = repeat_pipelines.find(key);
    if (existing != repeat_pipelines.end()) {
        return existing->second;
    }

    const char * type_define = nullptr;
    const char * type_suffix = nullptr;
    if (key.type == GGML_TYPE_F32) {
        type_define = "TYPE_F32";
        type_suffix = "_f32";
    } else if (key.type == GGML_TYPE_I32) {
        type_define = "TYPE_I32";
        type_suffix = "_i32";
    } else if (key.type == GGML_TYPE_I16) {
        type_define = "TYPE_I16";
        type_suffix = "_i16";
    } else {
        GGML_ABORT("Unsupported type for repeat shader");
    }

    std::string              variant = std::string("repeat") + type_suffix;
    std::vector<std::string> defines{ type_define, "WG_SIZE=" + std::to_string(context.max_wg_size) };

    auto repeat_decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    repeat_decisions->wg_size = context.max_wg_size;

    webgpu_pipeline repeat =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_repeat, defines), variant);
    repeat.context        = repeat_decisions;
    repeat_pipelines[key] = repeat;
    return repeat;
}
|
|
2723
|
+
|
|
2724
|
+
// Returns the flash-attention pipeline for the current configuration. It uses one of two shaders:
//  - wgsl_flash_attn, the subgroup-matrix path, when
//    ggml_webgpu_flash_attn_can_use_subgroup_matrix_path() accepts the device/tensors.
//    Here q_tile is context.sg_mat_m.
//  - wgsl_flash_attn_tile otherwise. Here q_tile is GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE and
//    kv_direct is forced off.
//
// The tile and workgroup decisions are computed before the cache lookup but are not stored in the
// key by this function. The pipeline cached on the first call therefore carries the decisions
// computed at that time.
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const bool use_sg = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
        context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);

    ggml_webgpu_flash_attn_decisions decisions = {};
    decisions.use_sg_matrix = use_sg;
    decisions.q_tile        = use_sg ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;

    ggml_webgpu_flash_attn_pipeline_key key = {};
    key.common           = ggml_webgpu_flash_attn_make_common_pipeline_key(context, use_sg ? context.sg_mat_k : 1u);
    key.common.kv_direct = use_sg && key.common.kv_direct;  // direct KV access only on the sg-matrix path
    key.use_sg_matrix    = use_sg;

    // Largest KV tile that the helper reports as fitting context.wg_mem_limit_bytes.
    const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
        context.wg_mem_limit_bytes, decisions.q_tile, use_sg ? context.sg_mat_n : 1u, key.common.head_dim_qk,
        key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
    GGML_ASSERT(max_kv_tile > 0);

    if (use_sg) {
        decisions.kv_tile = std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
        decisions.wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
    } else {
        decisions.kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
        decisions.wg_size = std::min(context.max_wg_size,
                                     std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
                                              GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
    }

    if (key.common.kv_direct) {
        // With direct KV access, kv_tile is clamped to GGML_WEBGPU_KV_SEQ_PAD. It is then shrunk one
        // step at a time until it divides GGML_WEBGPU_KV_SEQ_PAD evenly.
        const auto kv_step = use_sg ? context.sg_mat_n : context.min_subgroup_size;
        decisions.kv_tile  = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
        while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
            decisions.kv_tile -= kv_step;
        }
    }

    const auto cached = flash_attn_pipelines.find(key);
    if (cached != flash_attn_pipelines.end()) {
        return cached->second;
    }

    std::string              variant = use_sg ? "flash_attn" : "flash_attn_tile";
    std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
                                                                             decisions.kv_tile, decisions.wg_size);

    const char * shader_src = nullptr;
    if (use_sg) {
        shader_src = wgsl_flash_attn;
        defines.push_back("SG_MAT_M=" + std::to_string(context.sg_mat_m));
        defines.push_back("SG_MAT_N=" + std::to_string(context.sg_mat_n));
        defines.push_back("SG_MAT_K=" + std::to_string(context.sg_mat_k));
    } else {
        shader_src = wgsl_flash_attn_tile;
        const std::string sg_min = std::to_string(context.min_subgroup_size);
        const std::string sg_max = std::to_string(context.max_subgroup_size);
        defines.push_back("MIN_SUBGROUP_SIZE=" + sg_min + "u");
        defines.push_back("MAX_SUBGROUP_SIZE=" + sg_max + "u");
        variant += "_tile_sg" + sg_min + "_" + sg_max;
    }

    auto            fa_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
    webgpu_pipeline fa_pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
    fa_pipeline.context       = fa_decisions;
    flash_attn_pipelines[key] = fa_pipeline;
    return fa_pipeline;
}
|
|
2786
|
+
|
|
2787
|
+
// Returns the single-query ("vec") flash-attention pipeline (wgsl_flash_attn_vec_split), compiled
// once per common flash-attention key and cached.
//
// - The workgroup size is context.max_subgroup_size.
// - The KV tile comes from ggml_webgpu_flash_attn_get_vec_kv_tile.
// - With a mask, the BLK define is added and the label gains a "_blk" suffix.
// - VEC_NE is 2 or 4 for a few F16 K/V head sizes with head_dim_qk == head_dim_v, and 1 otherwise.
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_flash_attn_vec_pipeline_key key = {};
    key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);

    auto it = flash_attn_vec_pipelines.find(key);
    if (it != flash_attn_vec_pipelines.end()) {
        return it->second;
    }

    ggml_webgpu_flash_attn_vec_decisions decisions = {};
    decisions.kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
                                                               key.common.head_dim_v, key.common.has_mask,
                                                               key.common.kv_direct);
    decisions.wg_size = context.max_subgroup_size;

    std::string              variant = "flash_attn_vec";
    std::vector<std::string> defines =
        ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);
    if (key.common.has_mask) {
        defines.push_back("BLK");
        // Append "_blk" to the label. This is what the previous "strip 5 chars, append _mask_blk"
        // produced when the label ended in "_mask". Unlike that version, it cannot truncate
        // unrelated characters if the common helper's suffix convention changes.
        variant += "_blk";
    }

    uint32_t vec_ne = 1u;
    if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
        key.common.head_dim_qk == key.common.head_dim_v) {
        switch (key.common.head_dim_qk) {
            case 64:
            case 192:
            case 576:
                vec_ne = 2u;
                break;
            case 96:
                vec_ne = 4u;
                break;
            default:
                break;
        }
    }
    defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");

    auto            pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
    pipeline.context              = pipeline_decisions;
    flash_attn_vec_pipelines[key] = pipeline;
    return flash_attn_vec_pipelines[key];
}
|
|
2835
|
+
|
|
2836
|
+
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
    // Cached per KV tile size only.
    ggml_webgpu_flash_attn_blk_pipeline_key key = {};
    key.kv_tile = kv_tile;

    if (auto cached = flash_attn_blk_pipelines.find(key); cached != flash_attn_blk_pipelines.end()) {
        return cached->second;
    }

    // Round the device's max workgroup size down to a power of two (minimum 1).
    uint32_t wg_size = 1;
    for (uint32_t next = 2; next <= context.max_wg_size; next <<= 1) {
        wg_size = next;
    }

    const std::string kvt_str = std::to_string(key.kv_tile);
    const std::string wg_str  = std::to_string(wg_size);

    std::vector<std::string> defines = { "KV_TILE=" + kvt_str, "WG_SIZE=" + wg_str };
    const std::string        variant = "flash_attn_vec_blk_kvt" + kvt_str + "_wg" + wg_str;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant);
    flash_attn_blk_pipelines[key] = pipeline;
    return pipeline;
}
|
|
2862
|
+
|
|
2863
|
+
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (V head dim, destination type, workgroup size).
    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {};
    key.head_dim_v = (uint32_t) context.src2->ne[0];
    key.dst_type   = context.dst->type;
    key.wg_size    = context.max_wg_size;

    if (auto cached = flash_attn_vec_reduce_pipelines.find(key); cached != flash_attn_vec_reduce_pipelines.end()) {
        return cached->second;
    }

    const char * dst_define = nullptr;
    if (key.dst_type == GGML_TYPE_F32) {
        dst_define = "DST_F32";
    } else if (key.dst_type == GGML_TYPE_F16) {
        dst_define = "DST_F16";
    } else {
        GGML_ABORT("Unsupported dst type for flash attention vec reduce shader");
    }

    const std::string hsv_str = std::to_string(key.head_dim_v);
    const std::string wg_str  = std::to_string(context.max_wg_size);

    std::vector<std::string> defines = { dst_define, "HEAD_DIM_V=" + hsv_str, "WG_SIZE=" + wg_str };
    const std::string        variant = std::string("flash_attn_vec_reduce_dst") + ggml_type_name(key.dst_type) +
                                "_hsv" + hsv_str + "_wg" + wg_str;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant);
    flash_attn_vec_reduce_pipelines[key] = pipeline;
    return pipeline;
}
|
|
2899
|
+
|
|
2900
|
+
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (source type, destination type).
    ggml_webgpu_cpy_pipeline_key key = {};
    key.src_type = context.src0->type;
    key.dst_type = context.dst->type;

    if (auto cached = cpy_pipelines.find(key); cached != cpy_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "cpy";

    // Source element type: F32 or F16.
    if (key.src_type == GGML_TYPE_F32) {
        defines.push_back("SRC_F32");
        variant += "_f32";
    } else if (key.src_type == GGML_TYPE_F16) {
        defines.push_back("SRC_F16");
        variant += "_f16";
    } else {
        GGML_ABORT("Unsupported src type for cpy shader");
    }

    // Destination element type: F32, F16 or I32.
    if (key.dst_type == GGML_TYPE_F32) {
        defines.push_back("DST_F32");
        variant += "_f32";
    } else if (key.dst_type == GGML_TYPE_F16) {
        defines.push_back("DST_F16");
        variant += "_f16";
    } else if (key.dst_type == GGML_TYPE_I32) {
        defines.push_back("DST_I32");
        variant += "_i32";
    } else {
        GGML_ABORT("Unsupported dst type for cpy shader");
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_cpy, defines), variant);
    pipeline.context   = decisions;
    cpy_pipelines[key] = pipeline;
    return pipeline;
}
|
|
2953
|
+
|
|
2954
|
+
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (GLU op, element type, whether a separate gate tensor src1 is present).
    ggml_webgpu_glu_pipeline_key key = {};
    key.glu_op = ggml_get_glu_op(context.dst);
    key.type   = context.dst->type;
    key.split  = (context.src1 != nullptr);

    if (auto cached = glu_pipelines.find(key); cached != glu_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "glu";

    // Records a shader define together with its matching variant-name suffix.
    const auto add = [&](const char * define, const char * suffix) {
        defines.push_back(define);
        variant += suffix;
    };

    switch (key.glu_op) {
        case GGML_GLU_OP_REGLU:       add("OP_REGLU", "_reglu");             break;
        case GGML_GLU_OP_GEGLU:       add("OP_GEGLU", "_geglu");             break;
        case GGML_GLU_OP_SWIGLU:      add("OP_SWIGLU", "_swiglu");           break;
        case GGML_GLU_OP_SWIGLU_OAI:  add("OP_SWIGLU_OAI", "_swiglu_oai");   break;
        case GGML_GLU_OP_GEGLU_ERF:   add("OP_GEGLU_ERF", "_geglu_erf");     break;
        case GGML_GLU_OP_GEGLU_QUICK: add("OP_GEGLU_QUICK", "_geglu_quick"); break;
        default:                      GGML_ABORT("Unsupported GLU op");
    }

    switch (key.type) {
        case GGML_TYPE_F32: add("TYPE_F32", "_f32"); break;
        case GGML_TYPE_F16: add("TYPE_F16", "_f16"); break;
        default:            GGML_ABORT("Unsupported type for GLU shader");
    }

    // Split mode only changes the name; the fused mode needs the NO_SPLIT define.
    if (!key.split) {
        defines.push_back("NO_SPLIT");
    } else {
        variant += "_split";
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_glu, defines), variant);
    pipeline.context   = decisions;
    glu_pipelines[key] = pipeline;
    return pipeline;
}
|
|
3025
|
+
|
|
3026
|
+
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (element type, in-place, presence of src2 -> FF_FUNC).
    ggml_webgpu_rope_pipeline_key key = {};
    key.type    = context.dst->type;
    key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
    key.has_ff  = (context.src2 != nullptr);

    if (auto cached = rope_pipelines.find(key); cached != rope_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "rope";

    // Records a shader define together with its matching variant-name suffix.
    const auto enable = [&](const char * define, const char * suffix) {
        defines.push_back(define);
        variant += suffix;
    };

    if (key.type == GGML_TYPE_F32) {
        enable("TYPE_F32", "_f32");
    } else if (key.type == GGML_TYPE_F16) {
        enable("TYPE_F16", "_f16");
    } else {
        GGML_ABORT("Unsupported type for ROPE shader");
    }

    if (key.inplace) {
        enable("INPLACE", "_inplace");
    }
    if (key.has_ff) {
        enable("FF_FUNC", "_ff");
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_rope, defines), variant);
    pipeline.context    = decisions;
    rope_pipelines[key] = pipeline;
    return pipeline;
}
|
|
3074
|
+
|
|
3075
|
+
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (mask type/presence from src1, sink presence from src2, in-place).
    ggml_webgpu_soft_max_pipeline_key key = {};
    key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32;
    key.has_mask  = (context.src1 != nullptr);
    key.has_sink  = (context.src2 != nullptr);
    key.inplace   = ggml_webgpu_tensor_equal(context.src0, context.dst);

    if (auto cached = soft_max_pipelines.find(key); cached != soft_max_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "soft_max";

    // Records a shader define together with its matching variant-name suffix.
    const auto enable = [&](const char * define, const char * suffix) {
        defines.push_back(define);
        variant += suffix;
    };

    if (key.has_mask) {
        defines.push_back("HAS_MASK");
        if (key.mask_type == GGML_TYPE_F32) {
            enable("MASK_F32", "_mask_f32");
        } else if (key.mask_type == GGML_TYPE_F16) {
            enable("MASK_F16", "_mask_f16");
        } else {
            GGML_ABORT("Unsupported type for SOFT_MAX shader");
        }
    }

    if (key.has_sink) {
        enable("HAS_SINK", "_sink");
    }
    if (key.inplace) {
        enable("INPLACE", "_inplace");
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    decisions->inplace = key.inplace;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_soft_max, defines), variant);
    pipeline.context        = decisions;
    soft_max_pipelines[key] = pipeline;
    return pipeline;
}
|
|
3127
|
+
|
|
3128
|
+
webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (weight type from src0, input type from src1, output type from dst).
    ggml_webgpu_conv2d_pipeline_key key = {};
    key.weight_type = context.src0->type;
    key.input_type  = context.src1->type;
    key.output_type = context.dst->type;

    if (auto cached = conv2d_pipelines.find(key); cached != conv2d_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    const std::string        variant = "conv_2d";

    // Each operand contributes "<PREFIX>_F32" or "<PREFIX>_F16"; other types are rejected.
    struct operand_type {
        const char * prefix;
        ggml_type    type;
    };
    const operand_type operands[] = {
        { "WEIGHT", key.weight_type },
        { "INPUT",  key.input_type  },
        { "OUTPUT", key.output_type },
    };
    for (const operand_type & op : operands) {
        const char * suffix = nullptr;
        if (op.type == GGML_TYPE_F32) {
            suffix = "_F32";
        } else if (op.type == GGML_TYPE_F16) {
            suffix = "_F16";
        } else {
            GGML_ABORT("Unsupported type for CONV_2D shader");
        }
        defines.push_back(std::string(op.prefix) + suffix);
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_conv2d, defines), variant);
    pipeline.context      = decisions;
    conv2d_pipelines[key] = pipeline;
    return pipeline;
}
|
|
3167
|
+
|
|
3168
|
+
webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // Cached per (input type from src1, output type from dst).
    ggml_webgpu_im2col_pipeline_key key = {};
    key.input_type  = context.src1->type;
    key.output_type = context.dst->type;

    if (auto cached = im2col_pipelines.find(key); cached != im2col_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    const std::string        variant = "im2col";

    // Each operand contributes "<PREFIX>_F32" or "<PREFIX>_F16"; other types are rejected.
    struct operand_type {
        const char * prefix;
        ggml_type    type;
    };
    const operand_type operands[] = {
        { "INPUT",  key.input_type  },
        { "OUTPUT", key.output_type },
    };
    for (const operand_type & op : operands) {
        const char * suffix = nullptr;
        if (op.type == GGML_TYPE_F32) {
            suffix = "_F32";
        } else if (op.type == GGML_TYPE_F16) {
            suffix = "_F16";
        } else {
            GGML_ABORT("Unsupported type for IM2COL shader");
        }
        defines.push_back(std::string(op.prefix) + suffix);
    }

    defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_im2col, defines), variant);
    pipeline.context      = decisions;
    im2col_pipelines[key] = pipeline;
    return pipeline;
}
|
|
3205
|
+
|
|
3206
|
+
webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // op_params[0] packs the scale mode in its low byte plus flag bits (GGML_SCALE_FLAG_ANTIALIAS).
    const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0);
    const uint32_t base_mode  = mode_flags & 0xFFu;
    const bool     antialias  = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u;

    // Cached per (input type, output type, base scale mode, antialias flag).
    ggml_webgpu_upscale_pipeline_key key = {};
    key.input_type  = context.src0->type;
    key.output_type = context.dst->type;
    key.base_mode   = base_mode;
    key.antialias   = antialias;

    auto it = upscale_pipelines.find(key);
    if (it != upscale_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "upscale";

    // The shader only understands F32 (default) and F16 (SRC_F16/DST_F16). Reject anything else
    // explicitly, as the other shader builders do, instead of silently compiling the F32 variant
    // for a tensor of a different element type.
    switch (key.input_type) {
        case GGML_TYPE_F32:
            variant += "_src_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC_F16");
            variant += "_src_f16";
            break;
        default:
            GGML_ABORT("Unsupported input type for UPSCALE shader");
    }

    switch (key.output_type) {
        case GGML_TYPE_F32:
            variant += "_dst_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_dst_f16";
            break;
        default:
            GGML_ABORT("Unsupported output type for UPSCALE shader");
    }

    switch (base_mode) {
        case GGML_SCALE_MODE_NEAREST:
            defines.push_back("NEAREST");
            variant += "_nearest";
            break;
        case GGML_SCALE_MODE_BILINEAR:
            defines.push_back("BILINEAR");
            variant += "_bilinear";
            break;
        case GGML_SCALE_MODE_BICUBIC:
            defines.push_back("BICUBIC");
            variant += "_bicubic";
            break;
        default:
            GGML_ABORT("Unsupported upscale mode");
    }

    if (antialias) {
        defines.push_back("ANTIALIAS");
        variant += "_aa";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed     = preprocessor.preprocess(wgsl_upscale, defines);
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;
    upscale_pipelines[key]   = pipeline;
    return pipeline;
}
|
|
3271
|
+
|
|
3272
|
+
private:
|
|
3273
|
+
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
                                                   std::string    shader_code,
                                                   std::string    label) {
    // Compiles `shader_code` (WGSL) into a compute pipeline whose entry point is `main`, using an
    // automatically derived bind-group layout. `label` names both the shader module and the
    // pipeline, so validation/compilation messages identify the shader variant.
    wgpu::ShaderSourceWGSL shader_source;
    shader_source.code = shader_code.c_str();

    wgpu::ShaderModuleDescriptor shader_desc;
    shader_desc.nextInChain = &shader_source;
    shader_desc.label       = label.c_str();

    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

    wgpu::ComputePipelineDescriptor pipeline_desc;
    pipeline_desc.label              = label.c_str();
    pipeline_desc.compute.module     = shader_module;
    pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
    pipeline_desc.layout             = nullptr;  // nullptr means auto layout

    // Create the pipeline before touching `label`: the descriptor still points into its buffer.
    auto compute_pipeline = device.CreateComputePipeline(&pipeline_desc);

    // `label` is an owned by-value parameter, so move it into the result instead of copying.
    return { std::move(compute_pipeline), std::move(label) };
}
|
|
3291
|
+
};
|
|
3292
|
+
|
|
169
3293
|
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|