whispercpp 1.3.5 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -1,20 +1,273 @@
|
|
|
1
1
|
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
|
2
2
|
#define GGML_WEBGPU_SHADER_LIB_HPP
|
|
3
3
|
|
|
4
|
+
#include "ggml-wgsl-shaders.hpp"
|
|
4
5
|
#include "ggml.h"
|
|
5
6
|
#include "pre_wgsl.hpp"
|
|
6
7
|
|
|
8
|
+
#include <webgpu/webgpu_cpp.h>
|
|
9
|
+
|
|
10
|
+
#include <algorithm>
|
|
11
|
+
#include <memory>
|
|
7
12
|
#include <string>
|
|
13
|
+
#include <unordered_map>
|
|
8
14
|
#include <vector>
|
|
9
15
|
|
|
10
16
|
#define GGML_WEBGPU_F16_SIZE_BYTES 2
|
|
11
17
|
#define GGML_WEBGPU_F32_SIZE_BYTES 4
|
|
18
|
+
#define GGML_WEBGPU_I32_SIZE_BYTES 4
|
|
12
19
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
|
|
13
20
|
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
|
|
14
21
|
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
|
|
15
22
|
#define GGML_WEBGPU_KV_SEQ_PAD 256u
|
|
16
23
|
|
|
17
|
-
|
|
24
|
+
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
|
|
25
|
+
|
|
26
|
+
// Matrix multiplication parameters
|
|
27
|
+
|
|
28
|
+
// Register tiling parameters
|
|
29
|
+
#define WEBGPU_MUL_MAT_TILE_M 8
|
|
30
|
+
#define WEBGPU_MUL_MAT_TILE_N 8
|
|
31
|
+
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
|
|
32
|
+
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
|
|
33
|
+
#define WEBGPU_MUL_MAT_TILE_K 32
|
|
34
|
+
|
|
35
|
+
// Subgroup matrix parameters
|
|
36
|
+
// The number of subgroups in the M dimension
|
|
37
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
|
|
38
|
+
// The number of subgroups in the N dimension
|
|
39
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
|
|
40
|
+
// The number of subgroup matrices each subgroup accumulates over
|
|
41
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
|
|
42
|
+
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
|
43
|
+
|
|
44
|
+
// Matrix-vector multiplication parameters
|
|
45
|
+
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
|
46
|
+
|
|
47
|
+
// Must be multiple of 4 to work with vectorized paths, and must divide
|
|
48
|
+
// mul_mat_vec wg size
|
|
49
|
+
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
|
|
50
|
+
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
|
51
|
+
|
|
52
|
+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
|
|
53
|
+
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
|
54
|
+
|
|
55
|
+
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
|
56
|
+
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
|
57
|
+
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
|
58
|
+
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
|
59
|
+
|
|
60
|
+
// default size for legacy matrix multiplication
|
|
61
|
+
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
|
62
|
+
|
|
63
|
+
// Same hash combine function as in boost
|
|
64
|
+
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
|
|
65
|
+
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
struct ggml_webgpu_shader_lib_context {
|
|
69
|
+
ggml_tensor * src0;
|
|
70
|
+
ggml_tensor * src1;
|
|
71
|
+
ggml_tensor * src2;
|
|
72
|
+
ggml_tensor * src3;
|
|
73
|
+
ggml_tensor * src4;
|
|
74
|
+
ggml_tensor * dst;
|
|
75
|
+
|
|
76
|
+
uint32_t max_wg_size;
|
|
77
|
+
size_t wg_mem_limit_bytes = 0;
|
|
78
|
+
bool inplace = false;
|
|
79
|
+
bool overlap = false;
|
|
80
|
+
bool src_overlap = false;
|
|
81
|
+
bool supports_subgroup_matrix = false;
|
|
82
|
+
uint32_t sg_mat_m = 0;
|
|
83
|
+
uint32_t sg_mat_n = 0;
|
|
84
|
+
uint32_t sg_mat_k = 0;
|
|
85
|
+
uint32_t max_subgroup_size = 0;
|
|
86
|
+
};
|
|
87
|
+
|
|
88
|
+
struct webgpu_pipeline {
|
|
89
|
+
wgpu::ComputePipeline pipeline;
|
|
90
|
+
std::string name;
|
|
91
|
+
std::shared_ptr<void> context = nullptr;
|
|
92
|
+
};
|
|
93
|
+
|
|
94
|
+
struct ggml_webgpu_generic_shader_decisions {
|
|
95
|
+
uint32_t wg_size = 0;
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
/** Argsort **/
|
|
99
|
+
|
|
100
|
+
struct ggml_webgpu_argsort_shader_lib_context {
|
|
101
|
+
uint32_t max_wg_size;
|
|
102
|
+
size_t wg_mem_limit_bytes;
|
|
103
|
+
int32_t order;
|
|
104
|
+
};
|
|
105
|
+
|
|
106
|
+
/** Set Rows **/
|
|
107
|
+
|
|
108
|
+
struct ggml_webgpu_set_rows_pipeline_key {
|
|
109
|
+
int dst_type;
|
|
110
|
+
int vec4;
|
|
111
|
+
int i64_idx;
|
|
112
|
+
|
|
113
|
+
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
|
|
114
|
+
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
|
|
115
|
+
}
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
struct ggml_webgpu_set_rows_pipeline_key_hash {
|
|
119
|
+
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
|
|
120
|
+
size_t seed = 0;
|
|
121
|
+
ggml_webgpu_hash_combine(seed, key.dst_type);
|
|
122
|
+
ggml_webgpu_hash_combine(seed, key.vec4);
|
|
123
|
+
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
|
124
|
+
return seed;
|
|
125
|
+
}
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
struct ggml_webgpu_set_rows_shader_decisions {
|
|
129
|
+
bool vec4;
|
|
130
|
+
bool i64_idx;
|
|
131
|
+
uint32_t wg_size;
|
|
132
|
+
};
|
|
133
|
+
|
|
134
|
+
/** Get Rows **/
|
|
135
|
+
|
|
136
|
+
struct ggml_webgpu_get_rows_pipeline_key {
|
|
137
|
+
ggml_type src_type;
|
|
138
|
+
int vectorized;
|
|
139
|
+
|
|
140
|
+
bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
|
|
141
|
+
return src_type == other.src_type && vectorized == other.vectorized;
|
|
142
|
+
}
|
|
143
|
+
};
|
|
144
|
+
|
|
145
|
+
struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|
146
|
+
size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
|
|
147
|
+
size_t seed = 0;
|
|
148
|
+
ggml_webgpu_hash_combine(seed, key.src_type);
|
|
149
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
150
|
+
return seed;
|
|
151
|
+
}
|
|
152
|
+
};
|
|
153
|
+
|
|
154
|
+
/** Pad **/
|
|
155
|
+
struct ggml_webgpu_pad_pipeline_key {
|
|
156
|
+
bool circular;
|
|
157
|
+
|
|
158
|
+
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
|
|
159
|
+
};
|
|
160
|
+
|
|
161
|
+
struct ggml_webgpu_pad_pipeline_key_hash {
|
|
162
|
+
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
|
|
163
|
+
size_t seed = 0;
|
|
164
|
+
ggml_webgpu_hash_combine(seed, key.circular);
|
|
165
|
+
return seed;
|
|
166
|
+
}
|
|
167
|
+
};
|
|
168
|
+
|
|
169
|
+
/** Scale **/
|
|
170
|
+
|
|
171
|
+
struct ggml_webgpu_scale_pipeline_key {
|
|
172
|
+
int inplace;
|
|
173
|
+
|
|
174
|
+
bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
|
|
175
|
+
};
|
|
176
|
+
|
|
177
|
+
struct ggml_webgpu_scale_pipeline_key_hash {
|
|
178
|
+
size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
|
|
179
|
+
size_t seed = 0;
|
|
180
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
181
|
+
return seed;
|
|
182
|
+
}
|
|
183
|
+
};
|
|
184
|
+
|
|
185
|
+
/** Concat **/
|
|
186
|
+
|
|
187
|
+
struct ggml_webgpu_concat_pipeline_key {
|
|
188
|
+
int type;
|
|
189
|
+
|
|
190
|
+
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
|
|
191
|
+
};
|
|
192
|
+
|
|
193
|
+
struct ggml_webgpu_concat_pipeline_key_hash {
|
|
194
|
+
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
|
|
195
|
+
size_t seed = 0;
|
|
196
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
197
|
+
return seed;
|
|
198
|
+
}
|
|
199
|
+
};
|
|
200
|
+
|
|
201
|
+
/** Repeat **/
|
|
202
|
+
|
|
203
|
+
struct ggml_webgpu_repeat_pipeline_key {
|
|
204
|
+
int type;
|
|
205
|
+
|
|
206
|
+
bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
|
|
207
|
+
};
|
|
208
|
+
|
|
209
|
+
struct ggml_webgpu_repeat_pipeline_key_hash {
|
|
210
|
+
size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
|
|
211
|
+
size_t seed = 0;
|
|
212
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
213
|
+
return seed;
|
|
214
|
+
}
|
|
215
|
+
};
|
|
216
|
+
|
|
217
|
+
/** Binary **/
|
|
218
|
+
|
|
219
|
+
struct ggml_webgpu_binary_pipeline_key {
|
|
220
|
+
int type;
|
|
221
|
+
int op;
|
|
222
|
+
bool inplace;
|
|
223
|
+
bool overlap;
|
|
224
|
+
bool src_overlap;
|
|
225
|
+
|
|
226
|
+
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
|
227
|
+
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
|
|
228
|
+
src_overlap == other.src_overlap;
|
|
229
|
+
}
|
|
230
|
+
};
|
|
231
|
+
|
|
232
|
+
struct ggml_webgpu_binary_pipeline_key_hash {
|
|
233
|
+
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
|
|
234
|
+
size_t seed = 0;
|
|
235
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
236
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
237
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
238
|
+
ggml_webgpu_hash_combine(seed, key.overlap);
|
|
239
|
+
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
|
240
|
+
return seed;
|
|
241
|
+
}
|
|
242
|
+
};
|
|
243
|
+
|
|
244
|
+
/** Unary **/
|
|
245
|
+
|
|
246
|
+
struct ggml_webgpu_unary_pipeline_key {
|
|
247
|
+
int type;
|
|
248
|
+
int op;
|
|
249
|
+
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
|
250
|
+
bool inplace;
|
|
251
|
+
|
|
252
|
+
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
|
253
|
+
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
|
|
254
|
+
}
|
|
255
|
+
};
|
|
256
|
+
|
|
257
|
+
struct ggml_webgpu_unary_pipeline_key_hash {
|
|
258
|
+
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
|
|
259
|
+
size_t seed = 0;
|
|
260
|
+
ggml_webgpu_hash_combine(seed, key.type);
|
|
261
|
+
ggml_webgpu_hash_combine(seed, key.op);
|
|
262
|
+
ggml_webgpu_hash_combine(seed, key.is_unary);
|
|
263
|
+
ggml_webgpu_hash_combine(seed, key.inplace);
|
|
264
|
+
return seed;
|
|
265
|
+
}
|
|
266
|
+
};
|
|
267
|
+
|
|
268
|
+
/** FlashAttention */
|
|
269
|
+
|
|
270
|
+
struct ggml_webgpu_flash_attn_pipeline_key {
|
|
18
271
|
ggml_type kv_type;
|
|
19
272
|
uint32_t head_dim_qk;
|
|
20
273
|
uint32_t head_dim_v;
|
|
@@ -22,11 +275,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
|
|
|
22
275
|
bool has_mask;
|
|
23
276
|
bool has_sinks;
|
|
24
277
|
bool uses_logit_softcap;
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
278
|
+
|
|
279
|
+
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
|
280
|
+
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
|
281
|
+
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
|
282
|
+
uses_logit_softcap == other.uses_logit_softcap;
|
|
283
|
+
}
|
|
284
|
+
};
|
|
285
|
+
|
|
286
|
+
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
|
287
|
+
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
|
288
|
+
size_t seed = 0;
|
|
289
|
+
ggml_webgpu_hash_combine(seed, key.kv_type);
|
|
290
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
|
291
|
+
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
|
292
|
+
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
|
293
|
+
ggml_webgpu_hash_combine(seed, key.has_mask);
|
|
294
|
+
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
|
295
|
+
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
|
296
|
+
return seed;
|
|
297
|
+
}
|
|
298
|
+
};
|
|
299
|
+
|
|
300
|
+
struct ggml_webgpu_flash_attn_shader_lib_context {
|
|
301
|
+
ggml_webgpu_flash_attn_pipeline_key key;
|
|
302
|
+
uint32_t sg_mat_m;
|
|
303
|
+
uint32_t sg_mat_n;
|
|
304
|
+
uint32_t sg_mat_k;
|
|
305
|
+
size_t wg_mem_limit_bytes;
|
|
306
|
+
uint32_t max_subgroup_size;
|
|
30
307
|
};
|
|
31
308
|
|
|
32
309
|
struct ggml_webgpu_flash_attn_shader_decisions {
|
|
@@ -35,12 +312,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
|
|
|
35
312
|
uint32_t wg_size = 0;
|
|
36
313
|
};
|
|
37
314
|
|
|
38
|
-
struct ggml_webgpu_processed_shader {
|
|
39
|
-
std::string wgsl;
|
|
40
|
-
std::string variant;
|
|
41
|
-
ggml_webgpu_flash_attn_shader_decisions decisions;
|
|
42
|
-
};
|
|
43
|
-
|
|
44
315
|
// This is exposed because it's necessary in supports_op
|
|
45
316
|
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
46
317
|
uint32_t kv_tile,
|
|
@@ -65,105 +336,1039 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
|
|
65
336
|
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
|
66
337
|
}
|
|
67
338
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
339
|
+
/** Matrix Multiplication **/
|
|
340
|
+
|
|
341
|
+
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
|
|
342
|
+
ggml_type src0_type;
|
|
343
|
+
ggml_type src1_type;
|
|
344
|
+
|
|
345
|
+
bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
|
|
346
|
+
return src0_type == other.src0_type && src1_type == other.src1_type;
|
|
76
347
|
}
|
|
77
|
-
|
|
78
|
-
|
|
348
|
+
};
|
|
349
|
+
|
|
350
|
+
struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
|
|
351
|
+
size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
|
|
352
|
+
size_t seed = 0;
|
|
353
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
354
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
355
|
+
return seed;
|
|
79
356
|
}
|
|
80
|
-
|
|
81
|
-
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
|
|
82
|
-
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
|
|
83
|
-
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
|
|
84
|
-
}
|
|
357
|
+
};
|
|
85
358
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
uint32_t
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
359
|
+
struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
|
360
|
+
ggml_type src0_type;
|
|
361
|
+
ggml_type src1_type;
|
|
362
|
+
int vectorized;
|
|
363
|
+
|
|
364
|
+
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
|
365
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
|
|
366
|
+
}
|
|
367
|
+
};
|
|
368
|
+
|
|
369
|
+
struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
|
370
|
+
size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
|
|
371
|
+
size_t seed = 0;
|
|
372
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
373
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
374
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
375
|
+
return seed;
|
|
376
|
+
}
|
|
377
|
+
};
|
|
378
|
+
|
|
379
|
+
struct ggml_webgpu_mul_mat_vec_shader_decisions {
|
|
380
|
+
uint32_t wg_size;
|
|
381
|
+
uint32_t tile_k;
|
|
382
|
+
uint32_t outputs_per_wg;
|
|
383
|
+
uint32_t vec_size;
|
|
384
|
+
};
|
|
385
|
+
|
|
386
|
+
struct ggml_webgpu_mul_mat_pipeline_key {
|
|
387
|
+
ggml_type src0_type;
|
|
388
|
+
ggml_type src1_type;
|
|
389
|
+
int vectorized;
|
|
390
|
+
int use_subgroup_matrix;
|
|
391
|
+
|
|
392
|
+
bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
|
|
393
|
+
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
|
394
|
+
use_subgroup_matrix == other.use_subgroup_matrix;
|
|
395
|
+
}
|
|
396
|
+
};
|
|
397
|
+
|
|
398
|
+
struct ggml_webgpu_mul_mat_pipeline_key_hash {
|
|
399
|
+
size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
|
|
400
|
+
size_t seed = 0;
|
|
401
|
+
ggml_webgpu_hash_combine(seed, key.src0_type);
|
|
402
|
+
ggml_webgpu_hash_combine(seed, key.src1_type);
|
|
403
|
+
ggml_webgpu_hash_combine(seed, key.vectorized);
|
|
404
|
+
ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
|
|
405
|
+
return seed;
|
|
406
|
+
}
|
|
407
|
+
};
|
|
408
|
+
|
|
409
|
+
struct ggml_webgpu_mul_mat_shader_decisions {
|
|
410
|
+
uint32_t tile_k;
|
|
411
|
+
uint32_t wg_size_m;
|
|
412
|
+
uint32_t wg_size_n;
|
|
413
|
+
uint32_t wg_size;
|
|
414
|
+
uint32_t outputs_per_wg;
|
|
415
|
+
int use_subgroup_matrix;
|
|
416
|
+
|
|
417
|
+
uint32_t tile_m;
|
|
418
|
+
uint32_t tile_n;
|
|
419
|
+
|
|
420
|
+
// Subgroup matrix parameters
|
|
421
|
+
uint32_t subgroup_m;
|
|
422
|
+
uint32_t subgroup_n;
|
|
423
|
+
uint32_t subgroup_matrix_m;
|
|
424
|
+
uint32_t subgroup_matrix_n;
|
|
425
|
+
|
|
426
|
+
uint32_t mul_mat_wg_size;
|
|
427
|
+
};
|
|
428
|
+
|
|
429
|
+
class ggml_webgpu_shader_lib {
|
|
430
|
+
wgpu::Device device;
|
|
431
|
+
pre_wgsl::Preprocessor preprocessor;
|
|
432
|
+
|
|
433
|
+
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
|
434
|
+
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
|
|
435
|
+
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
|
|
436
|
+
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
|
|
437
|
+
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
|
438
|
+
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
|
439
|
+
get_rows_pipelines; // src_type, vectorized
|
|
440
|
+
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
|
441
|
+
unary_pipelines; // type/op/inplace
|
|
442
|
+
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
|
443
|
+
scale_pipelines; // inplace
|
|
444
|
+
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
|
445
|
+
pad_pipelines; // circular/non-circular
|
|
446
|
+
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
|
447
|
+
binary_pipelines; // type/op/inplace/overlap
|
|
448
|
+
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
|
449
|
+
concat_pipelines; // type
|
|
450
|
+
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
|
451
|
+
repeat_pipelines; // type
|
|
452
|
+
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
|
453
|
+
flash_attn_pipelines;
|
|
454
|
+
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
|
455
|
+
webgpu_pipeline,
|
|
456
|
+
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
|
|
457
|
+
mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec)
|
|
458
|
+
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
|
459
|
+
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
|
460
|
+
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
|
461
|
+
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
|
462
|
+
|
|
463
|
+
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
|
464
|
+
set_rows_pipelines;
|
|
465
|
+
|
|
466
|
+
public:
|
|
467
|
+
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
|
468
|
+
|
|
469
|
+
webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
470
|
+
auto it = sum_rows_pipelines.find(1);
|
|
471
|
+
if (it != sum_rows_pipelines.end()) {
|
|
472
|
+
return it->second;
|
|
473
|
+
}
|
|
474
|
+
std::vector<std::string> defines;
|
|
475
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
476
|
+
|
|
477
|
+
auto processed = preprocessor.preprocess(wgsl_sum_rows, defines);
|
|
478
|
+
sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
|
|
479
|
+
return sum_rows_pipelines[1];
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
483
|
+
bool vec4 = context.src0->ne[0] % 4 == 0;
|
|
484
|
+
|
|
485
|
+
auto it = argmax_pipelines.find(vec4);
|
|
486
|
+
if (it != argmax_pipelines.end()) {
|
|
487
|
+
return it->second;
|
|
488
|
+
}
|
|
489
|
+
std::string variant = "argmax";
|
|
490
|
+
std::vector<std::string> defines;
|
|
491
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
492
|
+
if (vec4) {
|
|
493
|
+
defines.push_back("VEC4");
|
|
494
|
+
variant += "_vec4";
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
auto processed = preprocessor.preprocess(wgsl_argmax, defines);
|
|
498
|
+
argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
499
|
+
return argmax_pipelines.at(vec4);
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
503
|
+
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
|
|
504
|
+
.vec4 = context.src0->ne[0] % 4 == 0,
|
|
505
|
+
.i64_idx = context.src1->type == GGML_TYPE_I64 };
|
|
506
|
+
|
|
507
|
+
auto it = set_rows_pipelines.find(key);
|
|
508
|
+
if (it != set_rows_pipelines.end()) {
|
|
509
|
+
return it->second;
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
std::vector<std::string> defines;
|
|
513
|
+
std::string variant = "set_rows";
|
|
514
|
+
|
|
515
|
+
switch (context.dst->type) {
|
|
516
|
+
case GGML_TYPE_F32:
|
|
517
|
+
defines.push_back("DST_F32");
|
|
518
|
+
variant += "_dstf32";
|
|
519
|
+
break;
|
|
520
|
+
case GGML_TYPE_F16:
|
|
521
|
+
defines.push_back("DST_F16");
|
|
522
|
+
variant += "_dstf16";
|
|
523
|
+
break;
|
|
524
|
+
default:
|
|
525
|
+
GGML_ABORT("Unsupported dst type for set_rows shader");
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
if (key.vec4) {
|
|
529
|
+
defines.push_back("VEC4");
|
|
530
|
+
variant += "_vec4";
|
|
531
|
+
}
|
|
532
|
+
if (key.i64_idx) {
|
|
533
|
+
defines.push_back("I64_IDX");
|
|
534
|
+
variant += "_i64idx";
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
538
|
+
|
|
539
|
+
auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
|
|
540
|
+
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
|
|
541
|
+
decisions->vec4 = key.vec4;
|
|
542
|
+
decisions->i64_idx = key.i64_idx;
|
|
543
|
+
decisions->wg_size = context.max_wg_size;
|
|
544
|
+
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
545
|
+
set_rows_pipelines[key].context = decisions;
|
|
546
|
+
return set_rows_pipelines[key];
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
550
|
+
auto it = cumsum_pipelines.find(1);
|
|
551
|
+
if (it != cumsum_pipelines.end()) {
|
|
552
|
+
return it->second;
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
std::vector<std::string> defines;
|
|
556
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
557
|
+
|
|
558
|
+
auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
|
|
559
|
+
cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
|
|
560
|
+
return cumsum_pipelines[1];
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
564
|
+
bool is_top_k = context.dst->op == GGML_OP_TOP_K;
|
|
565
|
+
// ascending order is 0, descending order is 1
|
|
566
|
+
const int32_t order =
|
|
567
|
+
is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
|
|
568
|
+
|
|
569
|
+
auto it = argsort_pipelines.find(order);
|
|
570
|
+
if (it != argsort_pipelines.end()) {
|
|
571
|
+
return it->second;
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
std::vector<std::string> defines;
|
|
575
|
+
std::string variant = "argsort";
|
|
576
|
+
defines.push_back(std::string("ORDER=") + std::to_string(order));
|
|
577
|
+
variant += std::string("_order") + std::to_string(order);
|
|
578
|
+
uint32_t wg_size = 1;
|
|
579
|
+
while (wg_size * 2 <= context.max_wg_size &&
|
|
580
|
+
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
|
|
581
|
+
wg_size *= 2;
|
|
582
|
+
}
|
|
583
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
584
|
+
auto processed = preprocessor.preprocess(wgsl_argsort, defines);
|
|
585
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
586
|
+
decisions->wg_size = wg_size;
|
|
587
|
+
argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
588
|
+
argsort_pipelines[order].context = decisions;
|
|
589
|
+
return argsort_pipelines[order];
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
593
|
+
bool is_top_k = context.dst->op == GGML_OP_TOP_K;
|
|
594
|
+
// ascending order is 0, descending order is 1
|
|
595
|
+
const int32_t order =
|
|
596
|
+
is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
|
|
597
|
+
|
|
598
|
+
auto it = argsort_merge_pipelines.find(order);
|
|
599
|
+
if (it != argsort_merge_pipelines.end()) {
|
|
600
|
+
return it->second;
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
std::vector<std::string> defines;
|
|
604
|
+
std::string variant = "argsort_merge";
|
|
605
|
+
defines.push_back(std::string("ORDER=") + std::to_string(order));
|
|
606
|
+
variant += std::string("_order") + std::to_string(order);
|
|
607
|
+
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
|
|
608
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
609
|
+
|
|
610
|
+
auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines);
|
|
611
|
+
argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
612
|
+
return argsort_merge_pipelines[order];
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
616
|
+
const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
|
|
617
|
+
ggml_webgpu_get_rows_pipeline_key key = {
|
|
618
|
+
.src_type = context.src0->type,
|
|
619
|
+
.vectorized = (int) vectorized,
|
|
620
|
+
};
|
|
621
|
+
|
|
622
|
+
auto it = get_rows_pipelines.find(key);
|
|
623
|
+
if (it != get_rows_pipelines.end()) {
|
|
624
|
+
return it->second;
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
std::vector<std::string> defines;
|
|
628
|
+
std::string variant = "get_rows";
|
|
629
|
+
|
|
630
|
+
const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
|
|
631
|
+
const char * type_str = type_traits->type_name;
|
|
632
|
+
|
|
633
|
+
switch (key.src_type) {
|
|
634
|
+
case GGML_TYPE_F32:
|
|
635
|
+
if (key.vectorized) {
|
|
636
|
+
defines.push_back("F32_VEC");
|
|
637
|
+
defines.push_back("SRC_TYPE=vec4<f32>");
|
|
638
|
+
defines.push_back("DST_TYPE=vec4<f32>");
|
|
639
|
+
defines.push_back("BLOCK_SIZE=4u");
|
|
640
|
+
} else {
|
|
641
|
+
defines.push_back("F32");
|
|
642
|
+
defines.push_back("SRC_TYPE=f32");
|
|
643
|
+
defines.push_back("DST_TYPE=f32");
|
|
644
|
+
defines.push_back("BLOCK_SIZE=1u");
|
|
645
|
+
}
|
|
646
|
+
variant += "_f32";
|
|
647
|
+
break;
|
|
648
|
+
case GGML_TYPE_F16:
|
|
649
|
+
defines.push_back("F16");
|
|
650
|
+
defines.push_back("SRC_TYPE=f16");
|
|
651
|
+
defines.push_back("DST_TYPE=f32");
|
|
652
|
+
defines.push_back("BLOCK_SIZE=1u");
|
|
653
|
+
variant += "_f16";
|
|
654
|
+
break;
|
|
655
|
+
case GGML_TYPE_I32:
|
|
656
|
+
defines.push_back("I32");
|
|
657
|
+
defines.push_back("SRC_TYPE=i32");
|
|
658
|
+
defines.push_back("DST_TYPE=i32");
|
|
659
|
+
defines.push_back("BLOCK_SIZE=1u");
|
|
660
|
+
variant += "_i32";
|
|
661
|
+
break;
|
|
662
|
+
default:
|
|
663
|
+
{
|
|
664
|
+
std::string type_upper = type_str;
|
|
665
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
666
|
+
|
|
667
|
+
defines.push_back("BYTE_HELPERS");
|
|
668
|
+
defines.push_back(type_upper + "_T");
|
|
669
|
+
defines.push_back(type_upper);
|
|
670
|
+
defines.push_back(type_upper + "_SCALE_MIN");
|
|
671
|
+
defines.push_back(type_upper + "_TABLES");
|
|
672
|
+
defines.push_back(type_upper + "_GRID");
|
|
673
|
+
|
|
674
|
+
variant += "_";
|
|
675
|
+
variant += type_str;
|
|
676
|
+
|
|
677
|
+
defines.push_back(std::string("SRC_TYPE=") + type_str);
|
|
678
|
+
defines.push_back("DST_TYPE=f32");
|
|
679
|
+
|
|
680
|
+
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
|
681
|
+
key.src_type == GGML_TYPE_IQ4_NL) {
|
|
682
|
+
defines.push_back("BLOCK_SIZE=32u");
|
|
683
|
+
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
|
684
|
+
defines.push_back("BLOCK_SIZE=256u");
|
|
685
|
+
} else {
|
|
686
|
+
defines.push_back("BLOCK_SIZE=1u");
|
|
687
|
+
}
|
|
688
|
+
break;
|
|
689
|
+
}
|
|
690
|
+
}
|
|
691
|
+
|
|
692
|
+
if (key.vectorized) {
|
|
693
|
+
variant += "_vec";
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
|
|
697
|
+
|
|
698
|
+
auto processed = preprocessor.preprocess(wgsl_get_rows, defines);
|
|
699
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
700
|
+
decisions->wg_size = context.max_wg_size;
|
|
701
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
702
|
+
pipeline.context = decisions;
|
|
703
|
+
get_rows_pipelines[key] = pipeline;
|
|
704
|
+
return get_rows_pipelines[key];
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
708
|
+
ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
|
|
709
|
+
|
|
710
|
+
auto it = scale_pipelines.find(key);
|
|
711
|
+
if (it != scale_pipelines.end()) {
|
|
712
|
+
return it->second;
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
std::vector<std::string> defines;
|
|
716
|
+
std::string variant = "scale";
|
|
717
|
+
|
|
718
|
+
if (key.inplace) {
|
|
719
|
+
defines.push_back("INPLACE");
|
|
720
|
+
variant += "_inplace";
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
724
|
+
|
|
725
|
+
auto processed = preprocessor.preprocess(wgsl_scale, defines);
|
|
726
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
727
|
+
decisions->wg_size = context.max_wg_size;
|
|
728
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
729
|
+
pipeline.context = decisions;
|
|
730
|
+
scale_pipelines[key] = pipeline;
|
|
731
|
+
return scale_pipelines[key];
|
|
732
|
+
}
|
|
733
|
+
|
|
734
|
+
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
735
|
+
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
|
|
736
|
+
|
|
737
|
+
auto it = pad_pipelines.find(key);
|
|
738
|
+
if (it != pad_pipelines.end()) {
|
|
739
|
+
return it->second;
|
|
740
|
+
}
|
|
741
|
+
|
|
742
|
+
std::vector<std::string> defines;
|
|
743
|
+
std::string variant = "pad";
|
|
744
|
+
|
|
745
|
+
if (key.circular) {
|
|
746
|
+
defines.push_back("CIRCULAR");
|
|
747
|
+
variant += "_circular";
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
751
|
+
|
|
752
|
+
auto processed = preprocessor.preprocess(wgsl_pad, defines);
|
|
753
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
754
|
+
decisions->wg_size = context.max_wg_size;
|
|
755
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
756
|
+
pipeline.context = decisions;
|
|
757
|
+
pad_pipelines[key] = pipeline;
|
|
758
|
+
return pad_pipelines[key];
|
|
759
|
+
}
|
|
760
|
+
|
|
761
|
+
webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
762
|
+
ggml_webgpu_mul_mat_vec_pipeline_key key = {
|
|
763
|
+
.src0_type = context.src0->type,
|
|
764
|
+
.src1_type = context.src1->type,
|
|
765
|
+
// Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
|
|
766
|
+
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
|
767
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
768
|
+
1 :
|
|
769
|
+
0,
|
|
770
|
+
};
|
|
771
|
+
|
|
772
|
+
auto it = mul_mat_vec_pipelines.find(key);
|
|
773
|
+
if (it != mul_mat_vec_pipelines.end()) {
|
|
774
|
+
return it->second;
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
std::vector<std::string> defines;
|
|
778
|
+
std::string variant = "mul_mat_vec";
|
|
779
|
+
|
|
780
|
+
// src0 type (matrix row)
|
|
781
|
+
switch (context.src0->type) {
|
|
782
|
+
case GGML_TYPE_F32:
|
|
783
|
+
defines.push_back("SRC0_INNER_TYPE=f32");
|
|
784
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
785
|
+
variant += "_f32";
|
|
786
|
+
break;
|
|
787
|
+
case GGML_TYPE_F16:
|
|
788
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
789
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
790
|
+
variant += "_f16";
|
|
791
|
+
break;
|
|
792
|
+
default:
|
|
793
|
+
{
|
|
794
|
+
// Quantized types: use helpers but accumulate in f16
|
|
795
|
+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
796
|
+
std::string src0_name = src0_traits->type_name;
|
|
797
|
+
std::string type_upper = src0_name;
|
|
798
|
+
variant += "_" + src0_name;
|
|
799
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
800
|
+
|
|
801
|
+
defines.push_back("BYTE_HELPERS");
|
|
802
|
+
defines.push_back("MUL_ACC_" + type_upper);
|
|
803
|
+
|
|
804
|
+
// For fast path we always dequantize from f16 inside the shader
|
|
805
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
806
|
+
break;
|
|
807
|
+
}
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
// src1 type (vector)
|
|
811
|
+
switch (context.src1->type) {
|
|
812
|
+
case GGML_TYPE_F32:
|
|
813
|
+
defines.push_back("SRC1_INNER_TYPE=f32");
|
|
814
|
+
variant += "_f32";
|
|
815
|
+
break;
|
|
816
|
+
case GGML_TYPE_F16:
|
|
817
|
+
defines.push_back("SRC1_INNER_TYPE=f16");
|
|
818
|
+
variant += "_f16";
|
|
819
|
+
break;
|
|
820
|
+
default:
|
|
821
|
+
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
// VEC/SCALAR controls
|
|
825
|
+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
826
|
+
|
|
827
|
+
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
|
828
|
+
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
|
829
|
+
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
|
830
|
+
|
|
831
|
+
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
|
832
|
+
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
|
|
833
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
|
834
|
+
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
|
835
|
+
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
|
836
|
+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
|
837
|
+
}
|
|
838
|
+
|
|
839
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
|
840
|
+
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
|
841
|
+
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
|
842
|
+
|
|
843
|
+
auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
|
|
844
|
+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
|
845
|
+
decisions->wg_size = wg_size;
|
|
846
|
+
decisions->tile_k = tile_k;
|
|
847
|
+
decisions->outputs_per_wg = outputs_per_wg;
|
|
848
|
+
decisions->vec_size = key.vectorized ? 4 : 1;
|
|
849
|
+
|
|
850
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
851
|
+
pipeline.context = decisions;
|
|
852
|
+
mul_mat_vec_pipelines[key] = pipeline;
|
|
853
|
+
return mul_mat_vec_pipelines[key];
|
|
854
|
+
}
|
|
855
|
+
|
|
856
|
+
webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
857
|
+
ggml_webgpu_mul_mat_pipeline_key key = {
|
|
858
|
+
.src0_type = context.src0->type,
|
|
859
|
+
.src1_type = context.src1->type,
|
|
860
|
+
.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
|
|
861
|
+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
|
862
|
+
1 :
|
|
863
|
+
0,
|
|
864
|
+
.use_subgroup_matrix = context.supports_subgroup_matrix
|
|
865
|
+
};
|
|
866
|
+
|
|
867
|
+
auto it = mul_mat_fast_pipelines.find(key);
|
|
868
|
+
if (it != mul_mat_fast_pipelines.end()) {
|
|
869
|
+
return it->second;
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
|
|
873
|
+
std::vector<std::string> defines;
|
|
874
|
+
std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
|
|
875
|
+
|
|
876
|
+
// src1 type
|
|
877
|
+
switch (context.src1->type) {
|
|
878
|
+
case GGML_TYPE_F32:
|
|
879
|
+
defines.push_back("SRC1_INNER_TYPE=f32");
|
|
880
|
+
break;
|
|
881
|
+
case GGML_TYPE_F16:
|
|
882
|
+
defines.push_back("SRC1_INNER_TYPE=f16");
|
|
883
|
+
break;
|
|
884
|
+
default:
|
|
885
|
+
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
// src0 type
|
|
889
|
+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
890
|
+
const char * src0_name = src0_traits->type_name;
|
|
891
|
+
|
|
892
|
+
switch (context.src0->type) {
|
|
893
|
+
case GGML_TYPE_F32:
|
|
894
|
+
defines.push_back("SRC0_INNER_TYPE=f32");
|
|
895
|
+
defines.push_back("FLOAT");
|
|
896
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
897
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
898
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
899
|
+
variant += "_f32";
|
|
900
|
+
break;
|
|
901
|
+
case GGML_TYPE_F16:
|
|
902
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
903
|
+
defines.push_back("FLOAT");
|
|
904
|
+
defines.push_back("MUL_ACC_FLOAT");
|
|
905
|
+
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
|
906
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
907
|
+
variant += "_f16";
|
|
908
|
+
break;
|
|
909
|
+
default:
|
|
910
|
+
{
|
|
911
|
+
std::string type_upper = src0_name;
|
|
912
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
913
|
+
|
|
914
|
+
defines.push_back("BYTE_HELPERS");
|
|
915
|
+
defines.push_back("MUL_ACC_" + type_upper);
|
|
916
|
+
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
|
917
|
+
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
|
918
|
+
|
|
919
|
+
// Use f16 inside the shader for quantized types
|
|
920
|
+
defines.push_back("SRC0_INNER_TYPE=f16");
|
|
921
|
+
|
|
922
|
+
variant += std::string("_") + src0_name;
|
|
923
|
+
break;
|
|
924
|
+
}
|
|
925
|
+
}
|
|
926
|
+
|
|
927
|
+
// VEC/SCALAR controls
|
|
928
|
+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
|
929
|
+
|
|
930
|
+
// Tiles
|
|
931
|
+
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
|
932
|
+
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
|
933
|
+
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
|
934
|
+
|
|
935
|
+
// Subgroup matrix specifics
|
|
936
|
+
if (key.use_subgroup_matrix) {
|
|
937
|
+
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
|
938
|
+
defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
|
939
|
+
defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
|
940
|
+
defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
|
|
941
|
+
defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
|
|
942
|
+
defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
|
|
943
|
+
defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
|
|
944
|
+
defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
|
|
945
|
+
}
|
|
946
|
+
|
|
947
|
+
// variant suffix for src1 type
|
|
948
|
+
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
|
949
|
+
if (key.vectorized) {
|
|
950
|
+
variant += "_vectorized";
|
|
951
|
+
}
|
|
952
|
+
|
|
953
|
+
if (!key.use_subgroup_matrix) {
|
|
954
|
+
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
|
955
|
+
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
|
956
|
+
}
|
|
957
|
+
|
|
958
|
+
auto processed = preprocessor.preprocess(shader_src, defines);
|
|
959
|
+
|
|
960
|
+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
|
961
|
+
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
|
|
962
|
+
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
|
963
|
+
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
|
964
|
+
decisions->use_subgroup_matrix = key.use_subgroup_matrix;
|
|
965
|
+
if (key.use_subgroup_matrix) {
|
|
966
|
+
decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
|
|
967
|
+
decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
|
|
968
|
+
decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
|
|
969
|
+
decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
|
|
970
|
+
decisions->wg_size = context.max_subgroup_size;
|
|
971
|
+
} else {
|
|
972
|
+
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
|
973
|
+
decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
974
|
+
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
|
|
975
|
+
decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
|
976
|
+
}
|
|
977
|
+
|
|
978
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
979
|
+
pipeline.context = decisions;
|
|
980
|
+
mul_mat_fast_pipelines[key] = pipeline;
|
|
981
|
+
return mul_mat_fast_pipelines[key];
|
|
982
|
+
}
|
|
983
|
+
|
|
984
|
+
webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
985
|
+
ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
|
|
986
|
+
.src1_type = context.src1->type };
|
|
987
|
+
|
|
988
|
+
auto it = mul_mat_legacy_pipelines.find(key);
|
|
989
|
+
if (it != mul_mat_legacy_pipelines.end()) {
|
|
990
|
+
return it->second;
|
|
991
|
+
}
|
|
992
|
+
|
|
993
|
+
std::vector<std::string> defines;
|
|
994
|
+
std::string variant = "mul_mat";
|
|
995
|
+
|
|
996
|
+
switch (context.src1->type) {
|
|
997
|
+
case GGML_TYPE_F32:
|
|
998
|
+
defines.push_back("SRC1_TYPE=f32");
|
|
999
|
+
variant += "_f32";
|
|
1000
|
+
break;
|
|
1001
|
+
case GGML_TYPE_F16:
|
|
1002
|
+
defines.push_back("SRC1_TYPE=f16");
|
|
1003
|
+
variant += "_f16";
|
|
1004
|
+
break;
|
|
1005
|
+
default:
|
|
1006
|
+
GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
|
|
1007
|
+
}
|
|
1008
|
+
|
|
1009
|
+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
|
1010
|
+
const char * src0_name = src0_traits->type_name;
|
|
1011
|
+
|
|
1012
|
+
switch (context.src0->type) {
|
|
1013
|
+
case GGML_TYPE_F32:
|
|
1014
|
+
defines.push_back("SRC0_TYPE=f32");
|
|
1015
|
+
defines.push_back("FLOAT");
|
|
1016
|
+
variant += "_f32";
|
|
1017
|
+
break;
|
|
1018
|
+
case GGML_TYPE_F16:
|
|
1019
|
+
defines.push_back("SRC0_TYPE=f16");
|
|
1020
|
+
defines.push_back("FLOAT");
|
|
1021
|
+
variant += "_f16";
|
|
1022
|
+
break;
|
|
1023
|
+
default:
|
|
1024
|
+
{
|
|
1025
|
+
// quantized types
|
|
1026
|
+
std::string type_upper = src0_name;
|
|
1027
|
+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
|
1028
|
+
|
|
1029
|
+
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
|
|
1030
|
+
defines.push_back("BYTE_HELPERS");
|
|
1031
|
+
defines.push_back(type_upper + "_T");
|
|
1032
|
+
defines.push_back(type_upper);
|
|
1033
|
+
defines.push_back(type_upper + "_SCALE_MIN");
|
|
1034
|
+
defines.push_back(type_upper + "_TABLES");
|
|
1035
|
+
defines.push_back(type_upper + "_GRID");
|
|
1036
|
+
|
|
1037
|
+
variant += std::string("_") + src0_name;
|
|
1038
|
+
break;
|
|
1039
|
+
}
|
|
1040
|
+
}
|
|
1041
|
+
|
|
1042
|
+
auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
|
|
1043
|
+
|
|
1044
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1045
|
+
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
|
1046
|
+
|
|
1047
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1048
|
+
pipeline.context = decisions;
|
|
1049
|
+
mul_mat_legacy_pipelines[key] = pipeline;
|
|
1050
|
+
return mul_mat_legacy_pipelines[key];
|
|
1051
|
+
}
|
|
1052
|
+
|
|
1053
|
+
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1054
|
+
const bool is_unary = context.dst->op == GGML_OP_UNARY;
|
|
1055
|
+
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
|
|
1056
|
+
ggml_webgpu_unary_pipeline_key key = {
|
|
1057
|
+
.type = context.dst->type,
|
|
1058
|
+
.op = op,
|
|
1059
|
+
.is_unary = is_unary,
|
|
1060
|
+
.inplace = context.inplace,
|
|
1061
|
+
};
|
|
1062
|
+
|
|
1063
|
+
auto it = unary_pipelines.find(key);
|
|
1064
|
+
if (it != unary_pipelines.end()) {
|
|
1065
|
+
return it->second;
|
|
1066
|
+
}
|
|
1067
|
+
|
|
1068
|
+
std::vector<std::string> defines;
|
|
1069
|
+
std::string variant =
|
|
1070
|
+
key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
|
|
1071
|
+
defines.push_back(variant);
|
|
1072
|
+
|
|
1073
|
+
switch (key.type) {
|
|
1074
|
+
case GGML_TYPE_F32:
|
|
1075
|
+
defines.push_back("TYPE_F32");
|
|
1076
|
+
variant += "_f32";
|
|
1077
|
+
break;
|
|
1078
|
+
case GGML_TYPE_F16:
|
|
1079
|
+
defines.push_back("TYPE_F16");
|
|
1080
|
+
variant += "_f16";
|
|
1081
|
+
break;
|
|
1082
|
+
default:
|
|
1083
|
+
GGML_ABORT("Unsupported type for unary shader");
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
if (key.inplace) {
|
|
1087
|
+
defines.push_back("INPLACE");
|
|
1088
|
+
variant += "_inplace";
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1092
|
+
|
|
1093
|
+
auto processed = preprocessor.preprocess(wgsl_unary, defines);
|
|
1094
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1095
|
+
decisions->wg_size = context.max_wg_size;
|
|
1096
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1097
|
+
pipeline.context = decisions;
|
|
1098
|
+
unary_pipelines[key] = pipeline;
|
|
1099
|
+
return unary_pipelines[key];
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1103
|
+
ggml_webgpu_binary_pipeline_key key = {
|
|
1104
|
+
.type = context.dst->type,
|
|
1105
|
+
.op = context.dst->op,
|
|
1106
|
+
.inplace = context.inplace,
|
|
1107
|
+
.overlap = context.overlap,
|
|
1108
|
+
.src_overlap = context.src_overlap,
|
|
1109
|
+
};
|
|
1110
|
+
|
|
1111
|
+
auto it = binary_pipelines.find(key);
|
|
1112
|
+
if (it != binary_pipelines.end()) {
|
|
1113
|
+
return it->second;
|
|
1114
|
+
}
|
|
1115
|
+
|
|
1116
|
+
std::vector<std::string> defines;
|
|
1117
|
+
std::string op_name = ggml_op_name((ggml_op) key.op);
|
|
1118
|
+
std::string variant = op_name;
|
|
1119
|
+
|
|
1120
|
+
defines.push_back(std::string("OP_") + op_name);
|
|
1121
|
+
|
|
1122
|
+
switch (key.type) {
|
|
1123
|
+
case GGML_TYPE_F32:
|
|
1124
|
+
defines.push_back("TYPE_F32");
|
|
1125
|
+
variant += "_f32";
|
|
1126
|
+
break;
|
|
1127
|
+
case GGML_TYPE_F16:
|
|
1128
|
+
defines.push_back("TYPE_F16");
|
|
1129
|
+
variant += "_f16";
|
|
1130
|
+
break;
|
|
1131
|
+
default:
|
|
1132
|
+
GGML_ABORT("Unsupported type for binary shader");
|
|
1133
|
+
}
|
|
1134
|
+
|
|
1135
|
+
if (key.inplace) {
|
|
1136
|
+
defines.push_back("INPLACE");
|
|
1137
|
+
variant += "_inplace";
|
|
1138
|
+
} else if (key.overlap) {
|
|
1139
|
+
defines.push_back("OVERLAP");
|
|
1140
|
+
variant += "_overlap";
|
|
1141
|
+
} else if (key.src_overlap) {
|
|
1142
|
+
defines.push_back("SRC_OVERLAP");
|
|
1143
|
+
variant += "_src_overlap";
|
|
1144
|
+
}
|
|
1145
|
+
|
|
1146
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1147
|
+
|
|
1148
|
+
auto processed = preprocessor.preprocess(wgsl_binary, defines);
|
|
1149
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1150
|
+
decisions->wg_size = context.max_wg_size;
|
|
1151
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1152
|
+
pipeline.context = decisions;
|
|
1153
|
+
binary_pipelines[key] = pipeline;
|
|
1154
|
+
return binary_pipelines[key];
|
|
1155
|
+
}
|
|
1156
|
+
|
|
1157
|
+
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1158
|
+
ggml_webgpu_concat_pipeline_key key = {
|
|
1159
|
+
.type = context.dst->type,
|
|
1160
|
+
};
|
|
1161
|
+
|
|
1162
|
+
auto it = concat_pipelines.find(key);
|
|
1163
|
+
if (it != concat_pipelines.end()) {
|
|
1164
|
+
return it->second;
|
|
1165
|
+
}
|
|
1166
|
+
|
|
1167
|
+
std::vector<std::string> defines;
|
|
1168
|
+
std::string variant = "concat";
|
|
1169
|
+
|
|
1170
|
+
switch (key.type) {
|
|
1171
|
+
case GGML_TYPE_F32:
|
|
1172
|
+
defines.push_back("TYPE_F32");
|
|
1173
|
+
variant += "_f32";
|
|
1174
|
+
break;
|
|
1175
|
+
case GGML_TYPE_I32:
|
|
1176
|
+
defines.push_back("TYPE_I32");
|
|
1177
|
+
variant += "_i32";
|
|
1178
|
+
break;
|
|
1179
|
+
default:
|
|
1180
|
+
GGML_ABORT("Unsupported type for concat shader");
|
|
1181
|
+
}
|
|
1182
|
+
|
|
1183
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1184
|
+
|
|
1185
|
+
auto processed = preprocessor.preprocess(wgsl_concat, defines);
|
|
1186
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1187
|
+
decisions->wg_size = context.max_wg_size;
|
|
1188
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1189
|
+
pipeline.context = decisions;
|
|
1190
|
+
concat_pipelines[key] = pipeline;
|
|
1191
|
+
return concat_pipelines[key];
|
|
1192
|
+
}
|
|
1193
|
+
|
|
1194
|
+
webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
|
1195
|
+
ggml_webgpu_repeat_pipeline_key key = {
|
|
1196
|
+
.type = context.dst->type,
|
|
1197
|
+
};
|
|
1198
|
+
|
|
1199
|
+
auto it = repeat_pipelines.find(key);
|
|
1200
|
+
if (it != repeat_pipelines.end()) {
|
|
1201
|
+
return it->second;
|
|
1202
|
+
}
|
|
1203
|
+
|
|
1204
|
+
std::vector<std::string> defines;
|
|
1205
|
+
std::string variant = "repeat";
|
|
1206
|
+
|
|
1207
|
+
switch (key.type) {
|
|
1208
|
+
case GGML_TYPE_F32:
|
|
1209
|
+
defines.push_back("TYPE_F32");
|
|
1210
|
+
variant += "_f32";
|
|
1211
|
+
break;
|
|
1212
|
+
case GGML_TYPE_I32:
|
|
1213
|
+
defines.push_back("TYPE_I32");
|
|
1214
|
+
variant += "_i32";
|
|
1215
|
+
break;
|
|
1216
|
+
case GGML_TYPE_I16:
|
|
1217
|
+
defines.push_back("TYPE_I16");
|
|
1218
|
+
variant += "_i16";
|
|
1219
|
+
break;
|
|
1220
|
+
default:
|
|
1221
|
+
GGML_ABORT("Unsupported type for repeat shader");
|
|
1222
|
+
}
|
|
1223
|
+
|
|
1224
|
+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
|
1225
|
+
|
|
1226
|
+
auto processed = preprocessor.preprocess(wgsl_repeat, defines);
|
|
1227
|
+
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
|
1228
|
+
decisions->wg_size = context.max_wg_size;
|
|
1229
|
+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
|
1230
|
+
pipeline.context = decisions;
|
|
1231
|
+
repeat_pipelines[key] = pipeline;
|
|
1232
|
+
return repeat_pipelines[key];
|
|
1233
|
+
}
|
|
1234
|
+
|
|
1235
|
+
// Returns the compute pipeline for flash attention, specialized on KV type,
// head dimensions, and which optional inputs (mask/sinks/softcap) are present.
// Pipelines are cached in `flash_attn_pipelines` keyed on all of those.
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
    // src3 = attention mask, src4 = attention sinks; either may be absent.
    const bool has_mask = context.src3 != nullptr;
    const bool has_sinks = context.src4 != nullptr;

    // KV data can be read directly (no staging conversion) only when it is
    // already F16 and the head dim / sequence length tile evenly against the
    // subgroup matrix dimensions.
    bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
    (context.src1->ne[1] % context.sg_mat_n == 0);

    ggml_webgpu_flash_attn_pipeline_key key = {
        .kv_type = context.src1->type,
        .head_dim_qk = (uint32_t) context.src0->ne[0],
        .head_dim_v = (uint32_t) context.src2->ne[0],
        .kv_direct = kv_direct,
        .has_mask = has_mask,
        .has_sinks = has_sinks,
        // op_params[2] stores the logit softcap as raw float bits.
        // NOTE(review): pointer-cast type punning is the established ggml
        // pattern but formally UB in C++; memcpy/bit_cast would be stricter.
        .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
    };

    // Fast path: pipeline already compiled for this key.
    auto it = flash_attn_pipelines.find(key);
    if (it != flash_attn_pipelines.end()) {
        return it->second;
    }

    // Shader preprocessor defines plus a human-readable variant label
    // (used as the pipeline's debug name).
    std::vector<std::string> defines;
    std::string variant = "flash_attn";

    switch (key.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("KV_F16");
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("KV_Q4_0");
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("KV_Q8_0");
            break;
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(key.kv_type);

    if (key.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (key.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (key.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }
    if (key.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(key.head_dim_v);

    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

    // Tile sizing: one Q tile of sg_mat_m rows; KV tile is the largest size
    // that fits in workgroup memory, capped at the preferred multiple of
    // sg_mat_n.
    uint32_t q_tile = context.sg_mat_m;
    uint32_t kv_tile =
    std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
    context.wg_mem_limit_bytes, context.max_subgroup_size }),
    context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (key.kv_direct) {
        // Direct-KV loads require the padded KV sequence length to be an
        // exact multiple of the tile; shrink in sg_mat_n steps until it is.
        // NOTE(review): assumes some multiple of sg_mat_n > 0 divides
        // GGML_WEBGPU_KV_SEQ_PAD before kv_tile reaches 0 (else % by zero)
        // — presumably guaranteed by device limits; TODO confirm.
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
        }
    }

    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
    // Record the tiling decisions so the dispatch code can size its
    // workgroup grid to match what the shader was compiled with.
    auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
    decisions->q_tile = q_tile;
    decisions->kv_tile = kv_tile;
    decisions->wg_size = wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    flash_attn_pipelines[key] = pipeline;
    return flash_attn_pipelines[key];
}
|
|
1333
|
+
|
|
1334
|
+
private:
|
|
1335
|
+
// Compiles WGSL `shader_code` into a compute pipeline on `device`.
// `label` becomes the pipeline's debug label and is also stored in the
// returned webgpu_pipeline.
//
// Both strings are taken by const reference: they only need to outlive this
// call (c_str() pointers are consumed by CreateShaderModule /
// CreateComputePipeline before returning), and taking them by value copied
// the full preprocessed shader source on every pipeline creation.
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &      device,
                                                   const std::string & shader_code,
                                                   const std::string & label) {
    wgpu::ShaderSourceWGSL shader_source;
    shader_source.code = shader_code.c_str();

    wgpu::ShaderModuleDescriptor shader_desc;
    // Chained struct: the WGSL source rides on the descriptor's chain.
    shader_desc.nextInChain = &shader_source;

    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

    wgpu::ComputePipelineDescriptor pipeline_desc;
    pipeline_desc.label              = label.c_str();
    pipeline_desc.compute.module     = shader_module;
    pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
    pipeline_desc.layout             = nullptr; // nullptr means auto layout
    return { device.CreateComputePipeline(&pipeline_desc), label };
}
|
|
1353
|
+
|
|
1354
|
+
// Computes the largest KV tile (in rows, rounded down to a multiple of
// sg_mat_n) whose workgroup-memory footprint, together with the fixed
// per-Q-tile storage, fits within context.wg_mem_limit_bytes.
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
    const size_t limit_bytes = context.wg_mem_limit_bytes;
    const size_t q_tile = context.sg_mat_m;
    // Fixed cost independent of KV tile size: Q and output accumulators
    // (F16, head_dim_qk + head_dim_v columns per Q row) plus two F32
    // per-row scalars (presumably the running max and sum of the online
    // softmax — TODO confirm against the WGSL shader).
    const size_t base_q_bytes =
        (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
        2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
    // Per-KV-row cost, accumulated in F16 elements then scaled to bytes.
    size_t bytes_per_kv = 0;
    if (!context.key.kv_direct) {
        // Staging buffer for converted K/V rows; sized for the wider of the
        // two head dims since it is reused for both.
        bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
    }
    if (context.key.has_mask) {
        // One mask value per (Q row, KV row) pair.
        bytes_per_kv += q_tile;
    }
    // Score/probability tile: one value per (Q row, KV row) pair.
    bytes_per_kv += q_tile;
    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
    // NOTE(review): unsigned subtraction — if base_q_bytes ever exceeded
    // limit_bytes this would wrap to a huge value instead of 0; presumably
    // precluded by supported head dims / device limits, TODO confirm.
    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
    // Round down to a whole number of sg_mat_n-row subgroup tiles.
    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}
|
|
1372
|
+
};
|
|
168
1373
|
|
|
169
1374
|
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|