whispercpp 1.3.4 → 1.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +158 -44
- data/ext/extconf.rb +3 -2
- data/ext/ruby_whisper.c +34 -6
- data/ext/ruby_whisper.h +67 -0
- data/ext/ruby_whisper_context.c +236 -144
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +12 -13
- data/ext/ruby_whisper_params.c +47 -24
- data/ext/ruby_whisper_segment.c +84 -20
- data/ext/ruby_whisper_token.c +371 -0
- data/ext/ruby_whisper_transcribe.cpp +5 -2
- data/ext/ruby_whisper_vad_context.c +122 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +138 -0
- data/ext/ruby_whisper_vad_segments.c +105 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +129 -112
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +28 -15
- data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
- data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
- data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
- data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
- data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
- data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
- data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
- data/ext/sources/examples/talk-llama/llama-context.h +70 -23
- data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
- data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
- data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
- data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
- data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
- data/ext/sources/examples/talk-llama/llama-model.h +112 -18
- data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
- data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
- data/ext/sources/examples/talk-llama/llama.cpp +802 -21
- data/ext/sources/examples/talk-llama/llama.h +210 -39
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +704 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
- data/ext/sources/ggml/CMakeLists.txt +90 -56
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +5 -2
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +6 -0
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +14 -12
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +246 -21
- data/ext/sources/ggml/src/CMakeLists.txt +85 -11
- data/ext/sources/ggml/src/ggml-alloc.c +128 -50
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
- data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
- data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
- data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
- data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
- data/ext/sources/ggml/src/ggml-impl.h +129 -6
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
- data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- data/ext/sources/ggml/src/ggml.c +590 -64
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +106 -62
- data/ext/sources/tests/CMakeLists.txt +2 -2
- data/ext/sources/tests/test-vad-full.cpp +4 -2
- data/ext/sources/tests/test-vad.cpp +1 -1
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +162 -4
- data/test/test_context_params.rb +82 -0
- data/test/test_params.rb +16 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +81 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +100 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +27 -0
- data/whispercpp.gemspec +1 -1
- metadata +502 -37
- data/ext/sources/build-xcframework.sh +0 -571
- data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -7,37 +7,102 @@
|
|
|
7
7
|
|
|
8
8
|
#include "ggml-backend-impl.h"
|
|
9
9
|
#include "ggml-impl.h"
|
|
10
|
-
#include "ggml-
|
|
10
|
+
#include "ggml-webgpu-shader-lib.hpp"
|
|
11
|
+
|
|
12
|
+
#ifdef __EMSCRIPTEN__
|
|
13
|
+
# include <emscripten/emscripten.h>
|
|
14
|
+
#endif
|
|
11
15
|
|
|
12
16
|
#include <webgpu/webgpu_cpp.h>
|
|
13
17
|
|
|
18
|
+
#include <atomic>
|
|
14
19
|
#include <condition_variable>
|
|
20
|
+
#include <cstdint>
|
|
15
21
|
#include <cstring>
|
|
16
|
-
#
|
|
22
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
23
|
+
# include <iomanip>
|
|
24
|
+
#endif
|
|
25
|
+
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
|
|
26
|
+
# include <iostream>
|
|
27
|
+
#endif
|
|
28
|
+
#include <map>
|
|
29
|
+
#include <memory>
|
|
17
30
|
#include <mutex>
|
|
31
|
+
#include <optional>
|
|
18
32
|
#include <string>
|
|
33
|
+
#include <utility>
|
|
19
34
|
#include <vector>
|
|
20
35
|
|
|
36
|
+
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
|
37
|
+
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
|
|
38
|
+
|
|
39
|
+
// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
|
|
40
|
+
// Assumes that the total number of workgroups does not exceed max_per_dim^2.
|
|
41
|
+
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
|
|
42
|
+
wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
|
|
43
|
+
wg_x = CEIL_DIV(total_wg, wg_y);
|
|
44
|
+
}
|
|
45
|
+
|
|
21
46
|
#ifdef GGML_WEBGPU_DEBUG
|
|
22
47
|
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
|
|
23
|
-
# define WEBGPU_DEBUG_BUF_ELEMS
|
|
48
|
+
# define WEBGPU_DEBUG_BUF_ELEMS 512
|
|
24
49
|
#else
|
|
25
50
|
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
|
|
26
51
|
#endif // GGML_WEBGPU_DEBUG
|
|
27
52
|
|
|
53
|
+
#ifdef GGML_WEBGPU_CPU_PROFILE
|
|
54
|
+
// total timing (aggregated)
|
|
55
|
+
# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
|
|
56
|
+
|
|
57
|
+
# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
|
|
58
|
+
auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
|
|
59
|
+
double cpu_total_time_##id = \
|
|
60
|
+
std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
|
|
61
|
+
(ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
|
|
62
|
+
// fine-grained timing (not included in totals)
|
|
63
|
+
# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
|
|
64
|
+
|
|
65
|
+
# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
|
|
66
|
+
auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
|
|
67
|
+
double cpu_detail_time_##id = \
|
|
68
|
+
std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
|
|
69
|
+
(ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
|
|
70
|
+
#else
|
|
71
|
+
# define WEBGPU_CPU_PROFILE_TOTAL_START(id)
|
|
72
|
+
# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
|
|
73
|
+
# define WEBGPU_CPU_PROFILE_DETAIL_START(id)
|
|
74
|
+
# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
|
|
75
|
+
#endif // GGML_WEBGPU_CPU_PROFILE
|
|
76
|
+
|
|
77
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
78
|
+
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
|
|
79
|
+
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
|
|
80
|
+
#endif
|
|
81
|
+
|
|
28
82
|
/* Constants */
|
|
29
83
|
|
|
30
|
-
#define
|
|
31
|
-
#define
|
|
32
|
-
#define
|
|
84
|
+
#define WEBGPU_NUM_PARAM_BUFS 96u
|
|
85
|
+
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
|
|
86
|
+
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
|
87
|
+
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
|
88
|
+
// parameter buffer pool
|
|
89
|
+
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
|
|
33
90
|
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
|
34
|
-
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32
|
|
35
91
|
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
|
36
|
-
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4
|
|
92
|
+
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
|
93
|
+
|
|
94
|
+
// For operations which process a row in parallel, this seems like a reasonable
|
|
95
|
+
// default
|
|
96
|
+
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
|
|
97
|
+
|
|
98
|
+
// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
|
|
99
|
+
// implementations so this can be removed, necessary only for get_rows right now
|
|
100
|
+
#define WEBGPU_MAX_WG_SIZE 288
|
|
37
101
|
|
|
38
102
|
/* End Constants */
|
|
39
103
|
|
|
40
|
-
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
|
104
|
+
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
|
|
105
|
+
// their locations.
|
|
41
106
|
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
|
42
107
|
|
|
43
108
|
// Always returns the base offset of a tensor, regardless of views.
|
|
@@ -57,14 +122,98 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
|
57
122
|
wgpu::BufferUsage usage,
|
|
58
123
|
const char * label);
|
|
59
124
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
wgpu::Buffer
|
|
125
|
+
// Holds a pool of parameter buffers for WebGPU operations
|
|
126
|
+
struct webgpu_buf_pool {
|
|
127
|
+
std::vector<wgpu::Buffer> free;
|
|
128
|
+
|
|
129
|
+
// The pool must be synchronized because
|
|
130
|
+
// 1. The memset pool is shared globally by every ggml buffer,
|
|
131
|
+
// since allocating a pool per ggml buffer would consume too much memory.
|
|
132
|
+
// 2. For the per-thread buffer pools in webgpu_context,
|
|
133
|
+
// buffers are allocated and freed in Dawn callbacks,
|
|
134
|
+
// which can run on a different thread than the calling thread.
|
|
135
|
+
std::mutex mutex;
|
|
136
|
+
std::condition_variable cv;
|
|
137
|
+
size_t cur_pool_size;
|
|
138
|
+
size_t max_pool_size;
|
|
139
|
+
wgpu::Device device;
|
|
140
|
+
wgpu::BufferUsage dev_buf_usage;
|
|
141
|
+
size_t buf_size;
|
|
142
|
+
bool should_grow;
|
|
143
|
+
|
|
144
|
+
void init(wgpu::Device device,
|
|
145
|
+
int num_bufs,
|
|
146
|
+
size_t buf_size,
|
|
147
|
+
wgpu::BufferUsage dev_buf_usage,
|
|
148
|
+
bool should_grow = false,
|
|
149
|
+
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
|
|
150
|
+
this->max_pool_size = max_pool_size;
|
|
151
|
+
this->cur_pool_size = num_bufs;
|
|
152
|
+
this->device = device;
|
|
153
|
+
this->dev_buf_usage = dev_buf_usage;
|
|
154
|
+
this->buf_size = buf_size;
|
|
155
|
+
this->should_grow = should_grow;
|
|
156
|
+
for (int i = 0; i < num_bufs; i++) {
|
|
157
|
+
wgpu::Buffer dev_buf;
|
|
158
|
+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
|
159
|
+
free.push_back(dev_buf);
|
|
160
|
+
}
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
wgpu::Buffer alloc_bufs() {
|
|
164
|
+
std::unique_lock<std::mutex> lock(mutex);
|
|
165
|
+
if (!free.empty()) {
|
|
166
|
+
wgpu::Buffer buf = free.back();
|
|
167
|
+
free.pop_back();
|
|
168
|
+
return buf;
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
// Try growing the pool if no free buffers
|
|
172
|
+
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
|
173
|
+
cur_pool_size++;
|
|
174
|
+
wgpu::Buffer dev_buf;
|
|
175
|
+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
|
176
|
+
|
|
177
|
+
if (!dev_buf) {
|
|
178
|
+
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
|
|
179
|
+
}
|
|
180
|
+
return dev_buf;
|
|
181
|
+
}
|
|
182
|
+
cv.wait(lock, [this] { return !free.empty(); });
|
|
183
|
+
wgpu::Buffer buf = free.back();
|
|
184
|
+
free.pop_back();
|
|
185
|
+
return buf;
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
void free_bufs(std::vector<wgpu::Buffer> bufs) {
|
|
189
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
190
|
+
free.insert(free.end(), bufs.begin(), bufs.end());
|
|
191
|
+
cv.notify_all();
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
void cleanup() {
|
|
195
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
196
|
+
for (auto & buf : free) {
|
|
197
|
+
if (buf) {
|
|
198
|
+
buf.Destroy();
|
|
199
|
+
}
|
|
200
|
+
}
|
|
201
|
+
free.clear();
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
~webgpu_buf_pool() { this->cleanup(); }
|
|
205
|
+
};
|
|
206
|
+
|
|
207
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
208
|
+
struct webgpu_gpu_profile_bufs {
|
|
209
|
+
wgpu::Buffer host_buf;
|
|
210
|
+
wgpu::Buffer dev_buf;
|
|
211
|
+
wgpu::QuerySet query_set;
|
|
63
212
|
};
|
|
64
213
|
|
|
65
214
|
// Holds a pool of parameter buffers for WebGPU operations
|
|
66
|
-
struct
|
|
67
|
-
std::vector<
|
|
215
|
+
struct webgpu_gpu_profile_buf_pool {
|
|
216
|
+
std::vector<webgpu_gpu_profile_bufs> free;
|
|
68
217
|
|
|
69
218
|
std::mutex mutex;
|
|
70
219
|
|
|
@@ -78,21 +227,28 @@ struct webgpu_buf_pool {
|
|
|
78
227
|
for (int i = 0; i < num_bufs; i++) {
|
|
79
228
|
wgpu::Buffer host_buf;
|
|
80
229
|
wgpu::Buffer dev_buf;
|
|
81
|
-
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "
|
|
82
|
-
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "
|
|
83
|
-
|
|
230
|
+
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
|
|
231
|
+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
|
|
232
|
+
// Create a query set for 2 timestamps
|
|
233
|
+
wgpu::QuerySetDescriptor ts_query_set_desc = {};
|
|
234
|
+
|
|
235
|
+
ts_query_set_desc.type = wgpu::QueryType::Timestamp;
|
|
236
|
+
ts_query_set_desc.count = 2;
|
|
237
|
+
wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
|
|
238
|
+
|
|
239
|
+
free.push_back({ host_buf, dev_buf, ts_query_set });
|
|
84
240
|
}
|
|
85
241
|
}
|
|
86
242
|
|
|
87
|
-
|
|
243
|
+
webgpu_gpu_profile_bufs alloc_bufs() {
|
|
88
244
|
std::unique_lock<std::mutex> lock(mutex);
|
|
89
245
|
cv.wait(lock, [this] { return !free.empty(); });
|
|
90
|
-
|
|
246
|
+
webgpu_gpu_profile_bufs bufs = free.back();
|
|
91
247
|
free.pop_back();
|
|
92
248
|
return bufs;
|
|
93
249
|
}
|
|
94
250
|
|
|
95
|
-
void free_bufs(std::vector<
|
|
251
|
+
void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
|
|
96
252
|
std::lock_guard<std::mutex> lock(mutex);
|
|
97
253
|
free.insert(free.end(), bufs.begin(), bufs.end());
|
|
98
254
|
cv.notify_all();
|
|
@@ -103,101 +259,163 @@ struct webgpu_buf_pool {
|
|
|
103
259
|
for (auto & bufs : free) {
|
|
104
260
|
bufs.host_buf.Destroy();
|
|
105
261
|
bufs.dev_buf.Destroy();
|
|
262
|
+
bufs.query_set.Destroy();
|
|
106
263
|
}
|
|
107
264
|
free.clear();
|
|
108
265
|
}
|
|
266
|
+
|
|
267
|
+
~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
|
|
109
268
|
};
|
|
269
|
+
#endif
|
|
110
270
|
|
|
111
|
-
|
|
112
|
-
|
|
271
|
+
struct webgpu_command {
|
|
272
|
+
uint32_t num_kernels;
|
|
273
|
+
wgpu::CommandBuffer commands;
|
|
274
|
+
std::vector<wgpu::Buffer> params_bufs;
|
|
275
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
276
|
+
webgpu_gpu_profile_bufs timestamp_query_bufs;
|
|
277
|
+
std::string pipeline_name;
|
|
278
|
+
#endif
|
|
279
|
+
};
|
|
280
|
+
|
|
281
|
+
struct webgpu_capabilities {
|
|
282
|
+
wgpu::Limits limits;
|
|
283
|
+
bool supports_subgroup_matrix = false;
|
|
284
|
+
|
|
285
|
+
uint32_t sg_mat_m = 0;
|
|
286
|
+
uint32_t sg_mat_n = 0;
|
|
287
|
+
uint32_t sg_mat_k = 0;
|
|
288
|
+
|
|
289
|
+
uint32_t subgroup_size = 0;
|
|
290
|
+
uint32_t max_subgroup_size = 0;
|
|
291
|
+
size_t memset_bytes_per_thread;
|
|
292
|
+
};
|
|
293
|
+
|
|
294
|
+
// Stores global webgpu members
|
|
295
|
+
struct webgpu_global_context_struct {
|
|
113
296
|
wgpu::Instance instance;
|
|
114
297
|
wgpu::Adapter adapter;
|
|
115
298
|
wgpu::Device device;
|
|
116
299
|
wgpu::Queue queue;
|
|
117
|
-
wgpu::Limits limits;
|
|
118
|
-
|
|
119
|
-
// Separate this out from limits since on some Metal systems, the limit returned by
|
|
120
|
-
// querying the limits is higher than the actual allowed maximum.
|
|
121
|
-
uint32_t max_wg_size_x;
|
|
122
300
|
|
|
301
|
+
webgpu_capabilities capabilities;
|
|
302
|
+
// Shared buffer to move data from device to host
|
|
303
|
+
wgpu::Buffer get_tensor_staging_buf;
|
|
304
|
+
// Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
|
|
123
305
|
std::recursive_mutex mutex;
|
|
124
306
|
|
|
125
|
-
webgpu_buf_pool
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
wgpu::ComputePipeline memset_pipeline;
|
|
129
|
-
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
|
130
|
-
wgpu::ComputePipeline set_rows_pipeline;
|
|
131
|
-
wgpu::ComputePipeline get_rows_pipeline[30];
|
|
132
|
-
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
|
|
133
|
-
wgpu::ComputePipeline cpy_pipeline;
|
|
134
|
-
wgpu::ComputePipeline add_pipeline[2];
|
|
135
|
-
wgpu::ComputePipeline add_ip_pipeline[2];
|
|
136
|
-
wgpu::ComputePipeline mul_pipeline[2];
|
|
137
|
-
wgpu::ComputePipeline mul_ip_pipeline[2];
|
|
138
|
-
wgpu::ComputePipeline rms_norm_pipeline;
|
|
139
|
-
wgpu::ComputePipeline rms_norm_ip_pipeline;
|
|
140
|
-
|
|
141
|
-
size_t memset_bytes_per_thread;
|
|
142
|
-
|
|
143
|
-
// Staging buffer for reading data from the GPU
|
|
144
|
-
wgpu::Buffer get_tensor_staging_buf;
|
|
307
|
+
webgpu_buf_pool memset_buf_pool;
|
|
308
|
+
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
|
|
145
309
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
//
|
|
150
|
-
std::
|
|
151
|
-
|
|
152
|
-
std::vector<webgpu_pool_bufs> staged_set_row_error_bufs;
|
|
310
|
+
#ifdef GGML_WEBGPU_CPU_PROFILE
|
|
311
|
+
// Profiling: labeled CPU time in ms (total)
|
|
312
|
+
std::unordered_map<std::string, double> cpu_time_ms;
|
|
313
|
+
// Profiling: detailed CPU time in ms
|
|
314
|
+
std::unordered_map<std::string, double> cpu_detail_ms;
|
|
315
|
+
#endif
|
|
153
316
|
|
|
154
|
-
|
|
317
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
318
|
+
// Profiling: per-shader GPU time in ms
|
|
319
|
+
std::unordered_map<std::string, double> shader_gpu_time_ms;
|
|
320
|
+
// Profiling: pool of timestamp query buffers (one per operation)
|
|
321
|
+
webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
|
|
322
|
+
#endif
|
|
155
323
|
|
|
156
324
|
#ifdef GGML_WEBGPU_DEBUG
|
|
157
325
|
wgpu::Buffer debug_host_buf;
|
|
158
326
|
wgpu::Buffer debug_dev_buf;
|
|
159
327
|
#endif
|
|
328
|
+
|
|
329
|
+
~webgpu_global_context_struct() {
|
|
330
|
+
if (this->get_tensor_staging_buf) {
|
|
331
|
+
this->get_tensor_staging_buf.Destroy();
|
|
332
|
+
this->get_tensor_staging_buf = nullptr;
|
|
333
|
+
}
|
|
334
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
335
|
+
if (this->debug_host_buf) {
|
|
336
|
+
this->debug_host_buf.Destroy();
|
|
337
|
+
this->debug_host_buf = nullptr;
|
|
338
|
+
}
|
|
339
|
+
if (this->debug_dev_buf) {
|
|
340
|
+
this->debug_dev_buf.Destroy();
|
|
341
|
+
this->debug_dev_buf = nullptr;
|
|
342
|
+
}
|
|
343
|
+
#endif
|
|
344
|
+
}
|
|
345
|
+
};
|
|
346
|
+
|
|
347
|
+
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
|
|
348
|
+
|
|
349
|
+
struct webgpu_submission {
|
|
350
|
+
wgpu::FutureWaitInfo submit_done;
|
|
351
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
352
|
+
std::vector<wgpu::FutureWaitInfo> profile_futures;
|
|
353
|
+
#endif
|
|
354
|
+
};
|
|
355
|
+
|
|
356
|
+
// All the base objects needed to run operations on a WebGPU device
|
|
357
|
+
struct webgpu_context_struct {
|
|
358
|
+
// Points to global instances owned by ggml_backend_webgpu_reg_context
|
|
359
|
+
webgpu_global_context global_ctx;
|
|
360
|
+
|
|
361
|
+
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
|
362
|
+
|
|
363
|
+
webgpu_buf_pool param_buf_pool;
|
|
364
|
+
wgpu::Buffer set_rows_dev_error_buf;
|
|
365
|
+
wgpu::Buffer set_rows_host_error_buf;
|
|
366
|
+
|
|
367
|
+
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
|
368
|
+
|
|
369
|
+
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
|
370
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
|
371
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
|
|
372
|
+
|
|
373
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
|
|
374
|
+
|
|
375
|
+
size_t memset_bytes_per_thread;
|
|
160
376
|
};
|
|
161
377
|
|
|
162
378
|
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
|
|
163
379
|
|
|
380
|
+
// Metadata required for the ggml backend registration/discovery interface
|
|
164
381
|
struct ggml_backend_webgpu_reg_context {
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
382
|
+
// Since the Instance is a global entrypoint into the WebGPU API, it lives here
|
|
383
|
+
webgpu_global_context webgpu_global_ctx;
|
|
384
|
+
size_t device_count;
|
|
385
|
+
const char * name;
|
|
168
386
|
};
|
|
169
387
|
|
|
388
|
+
// Per-device struct for the global logical device interface
|
|
170
389
|
struct ggml_backend_webgpu_device_context {
|
|
171
|
-
|
|
172
|
-
std::string
|
|
173
|
-
std::string
|
|
390
|
+
webgpu_global_context webgpu_global_ctx;
|
|
391
|
+
std::string device_name;
|
|
392
|
+
std::string device_desc;
|
|
174
393
|
};
|
|
175
394
|
|
|
395
|
+
// Per-thread data required to actually run WebGPU operations in a backend instance
|
|
176
396
|
struct ggml_backend_webgpu_context {
|
|
177
397
|
webgpu_context webgpu_ctx;
|
|
178
398
|
std::string name;
|
|
179
399
|
};
|
|
180
400
|
|
|
401
|
+
// Per-thread data related to buffers
|
|
181
402
|
struct ggml_backend_webgpu_buffer_context {
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
buffer(std::move(buf))
|
|
403
|
+
wgpu::Buffer buffer;
|
|
404
|
+
std::string label;
|
|
405
|
+
webgpu_global_context global_ctx;
|
|
406
|
+
|
|
407
|
+
ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
|
|
408
|
+
buffer(std::move(buf)),
|
|
409
|
+
label(std::move(lbl)),
|
|
410
|
+
global_ctx(std::move(global_ctx_)) {}
|
|
188
411
|
};
|
|
189
412
|
|
|
190
|
-
/* End struct definitions */
|
|
191
|
-
|
|
192
413
|
/* WebGPU object initializations */
|
|
193
414
|
|
|
194
|
-
static
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
const std::vector<wgpu::ConstantEntry> & constants = {}) {
|
|
199
|
-
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
|
|
200
|
-
|
|
415
|
+
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
|
416
|
+
const char * shader_code,
|
|
417
|
+
const char * label,
|
|
418
|
+
const std::vector<wgpu::ConstantEntry> & constants = {}) {
|
|
201
419
|
wgpu::ShaderSourceWGSL shader_source;
|
|
202
420
|
shader_source.code = shader_code;
|
|
203
421
|
|
|
@@ -215,7 +433,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
|
|
|
215
433
|
pipeline_desc.compute.constants = constants.data();
|
|
216
434
|
pipeline_desc.compute.constantCount = constants.size();
|
|
217
435
|
}
|
|
218
|
-
|
|
436
|
+
return { device.CreateComputePipeline(&pipeline_desc), label };
|
|
219
437
|
}
|
|
220
438
|
|
|
221
439
|
static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
@@ -223,8 +441,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
|
223
441
|
size_t size,
|
|
224
442
|
wgpu::BufferUsage usage,
|
|
225
443
|
const char * label) {
|
|
226
|
-
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
|
|
227
|
-
|
|
228
444
|
wgpu::BufferDescriptor buffer_desc;
|
|
229
445
|
buffer_desc.size = size;
|
|
230
446
|
buffer_desc.usage = usage;
|
|
@@ -239,88 +455,113 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
|
239
455
|
|
|
240
456
|
/** WebGPU Actions */
|
|
241
457
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
|
|
259
|
-
ctx->callback_futures.clear();
|
|
458
|
+
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
|
|
459
|
+
switch (status) {
|
|
460
|
+
case wgpu::WaitStatus::Success:
|
|
461
|
+
return true;
|
|
462
|
+
case wgpu::WaitStatus::TimedOut:
|
|
463
|
+
if (allow_timeout) {
|
|
464
|
+
return false;
|
|
465
|
+
}
|
|
466
|
+
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
|
|
467
|
+
return false;
|
|
468
|
+
case wgpu::WaitStatus::Error:
|
|
469
|
+
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
|
470
|
+
return false;
|
|
471
|
+
default:
|
|
472
|
+
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
|
473
|
+
return false;
|
|
260
474
|
}
|
|
261
475
|
}
|
|
262
476
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
477
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
478
|
+
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
|
|
479
|
+
futures.erase(std::remove_if(futures.begin(), futures.end(),
|
|
480
|
+
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
|
|
481
|
+
futures.end());
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
|
|
485
|
+
std::vector<wgpu::FutureWaitInfo> & futures,
|
|
486
|
+
bool block) {
|
|
487
|
+
if (futures.empty()) {
|
|
268
488
|
return;
|
|
269
489
|
}
|
|
270
|
-
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
|
|
271
490
|
|
|
272
|
-
|
|
273
|
-
if (
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
491
|
+
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
|
492
|
+
if (block) {
|
|
493
|
+
while (!futures.empty()) {
|
|
494
|
+
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
|
495
|
+
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
|
496
|
+
ggml_backend_webgpu_erase_completed_futures(futures);
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
} else {
|
|
500
|
+
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
|
501
|
+
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
|
502
|
+
ggml_backend_webgpu_erase_completed_futures(futures);
|
|
278
503
|
}
|
|
279
|
-
wgpu::CommandBuffer commands = encoder.Finish();
|
|
280
|
-
ctx->queue.Submit(1, &commands);
|
|
281
504
|
}
|
|
505
|
+
}
|
|
506
|
+
#endif
|
|
282
507
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
508
|
+
// Wait for the queue to finish processing all submitted work
|
|
509
|
+
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
|
510
|
+
std::vector<webgpu_submission> & subs,
|
|
511
|
+
bool block = true) {
|
|
512
|
+
// If we have too many in-flight submissions, wait on the oldest one first.
|
|
513
|
+
if (subs.empty()) {
|
|
514
|
+
return;
|
|
515
|
+
}
|
|
516
|
+
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
|
517
|
+
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
|
|
518
|
+
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
|
519
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
520
|
+
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
|
521
|
+
#endif
|
|
522
|
+
subs.erase(subs.begin());
|
|
523
|
+
}
|
|
524
|
+
}
|
|
286
525
|
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
[ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
291
|
-
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
292
|
-
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
|
293
|
-
}
|
|
294
|
-
// Free the staged buffers
|
|
295
|
-
ctx->param_buf_pool.free_bufs(staged_param_bufs);
|
|
296
|
-
});
|
|
297
|
-
ctx->callback_futures.push_back({ p_f });
|
|
526
|
+
if (subs.empty()) {
|
|
527
|
+
return;
|
|
528
|
+
}
|
|
298
529
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
530
|
+
if (block) {
|
|
531
|
+
for (auto & sub : subs) {
|
|
532
|
+
while (!sub.submit_done.completed) {
|
|
533
|
+
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
|
|
534
|
+
ggml_backend_webgpu_handle_wait_status(waitStatus);
|
|
535
|
+
}
|
|
536
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
537
|
+
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
|
|
538
|
+
#endif
|
|
539
|
+
}
|
|
540
|
+
subs.clear();
|
|
541
|
+
} else {
|
|
542
|
+
// Poll each submit future once and remove completed submissions.
|
|
543
|
+
for (auto sub = subs.begin(); sub != subs.end();) {
|
|
544
|
+
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
|
545
|
+
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
|
546
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
547
|
+
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
|
548
|
+
if (sub->submit_done.completed && sub->profile_futures.empty()) {
|
|
549
|
+
#else
|
|
550
|
+
if (sub->submit_done.completed) {
|
|
551
|
+
#endif
|
|
552
|
+
sub = subs.erase(sub);
|
|
553
|
+
} else {
|
|
554
|
+
++sub;
|
|
555
|
+
}
|
|
556
|
+
}
|
|
316
557
|
}
|
|
317
558
|
}
|
|
318
559
|
|
|
319
|
-
static void ggml_backend_webgpu_map_buffer(
|
|
320
|
-
wgpu::Buffer &
|
|
321
|
-
wgpu::MapMode
|
|
322
|
-
size_t
|
|
323
|
-
size_t
|
|
560
|
+
static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
|
|
561
|
+
wgpu::Buffer & buffer,
|
|
562
|
+
wgpu::MapMode mode,
|
|
563
|
+
size_t offset,
|
|
564
|
+
size_t size) {
|
|
324
565
|
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
|
|
325
566
|
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
|
326
567
|
if (status != wgpu::MapAsyncStatus::Success) {
|
|
@@ -335,100 +576,178 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
|
|
|
335
576
|
// This function adds debugging information to shaders, as WebGPU does not support printing directly.
|
|
336
577
|
// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
|
|
337
578
|
// debug statements in the shader, and then call this function after encoding the commands and submitting them.
|
|
338
|
-
static void ggml_backend_webgpu_debug(
|
|
339
|
-
ggml_backend_webgpu_submit_queue(ctx);
|
|
579
|
+
static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
|
340
580
|
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
|
341
581
|
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
|
|
342
582
|
wgpu::CommandBuffer commands = encoder.Finish();
|
|
343
583
|
ctx->queue.Submit(1, &commands);
|
|
344
|
-
|
|
345
584
|
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
|
|
346
|
-
const
|
|
347
|
-
std::cout << "debug
|
|
348
|
-
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
|
|
349
|
-
std::cout << " " << i << ": " << debug_data[i];
|
|
350
|
-
}
|
|
351
|
-
std::cout << "\n";
|
|
585
|
+
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
|
|
586
|
+
std::cout << "debug[0]: " << debug_data[0] << "\n";
|
|
352
587
|
ctx->debug_host_buf.Unmap();
|
|
353
588
|
}
|
|
354
589
|
#endif
|
|
355
590
|
|
|
356
|
-
static
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
}
|
|
591
|
+
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
|
|
592
|
+
std::vector<webgpu_command> & commands,
|
|
593
|
+
webgpu_buf_pool & param_buf_pool) {
|
|
594
|
+
std::vector<wgpu::CommandBuffer> command_buffers;
|
|
595
|
+
std::vector<wgpu::Buffer> params_bufs;
|
|
596
|
+
webgpu_submission submission;
|
|
597
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
598
|
+
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
|
|
599
|
+
#endif
|
|
600
|
+
|
|
601
|
+
for (const auto & command : commands) {
|
|
602
|
+
command_buffers.push_back(command.commands);
|
|
603
|
+
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
|
|
604
|
+
}
|
|
605
|
+
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
|
|
606
|
+
|
|
607
|
+
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
|
|
608
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
609
|
+
[¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
610
|
+
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
611
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
|
|
612
|
+
}
|
|
613
|
+
// Free the staged buffers
|
|
614
|
+
param_buf_pool.free_bufs(params_bufs);
|
|
615
|
+
});
|
|
616
|
+
submission.submit_done = { p_f };
|
|
370
617
|
|
|
371
|
-
|
|
618
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
619
|
+
for (const auto & command : commands) {
|
|
620
|
+
auto label = command.pipeline_name;
|
|
621
|
+
auto ts_bufs = command.timestamp_query_bufs;
|
|
372
622
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
623
|
+
wgpu::Future f = ts_bufs.host_buf.MapAsync(
|
|
624
|
+
wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
|
625
|
+
[ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
|
626
|
+
if (status != wgpu::MapAsyncStatus::Success) {
|
|
627
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
|
|
628
|
+
} else {
|
|
629
|
+
const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
|
|
630
|
+
// WebGPU timestamps are in ns; convert to ms
|
|
631
|
+
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
|
|
632
|
+
ctx->shader_gpu_time_ms[label] += elapsed_ms;
|
|
633
|
+
}
|
|
634
|
+
// We can't unmap in here due to WebGPU reentrancy limitations.
|
|
635
|
+
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
|
636
|
+
});
|
|
637
|
+
submission.profile_futures.push_back({ f });
|
|
638
|
+
}
|
|
639
|
+
#endif
|
|
640
|
+
return submission;
|
|
641
|
+
}
|
|
378
642
|
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
643
|
+
static webgpu_command ggml_backend_webgpu_build_multi(
|
|
644
|
+
webgpu_global_context & ctx,
|
|
645
|
+
webgpu_buf_pool & param_buf_pool,
|
|
646
|
+
const std::vector<webgpu_pipeline> & pipelines,
|
|
647
|
+
const std::vector<std::vector<uint32_t>> & params_list,
|
|
648
|
+
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
|
|
649
|
+
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
|
|
650
|
+
GGML_ASSERT(pipelines.size() == params_list.size());
|
|
651
|
+
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
|
|
652
|
+
GGML_ASSERT(pipelines.size() == workgroups_list.size());
|
|
653
|
+
|
|
654
|
+
std::vector<wgpu::Buffer> params_bufs_list;
|
|
655
|
+
std::vector<wgpu::BindGroup> bind_groups;
|
|
656
|
+
|
|
657
|
+
for (size_t i = 0; i < pipelines.size(); i++) {
|
|
658
|
+
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
|
|
659
|
+
|
|
660
|
+
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
|
|
661
|
+
uint32_t params_binding_num = entries.size();
|
|
662
|
+
entries.push_back(
|
|
663
|
+
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
|
|
664
|
+
|
|
665
|
+
wgpu::BindGroupDescriptor bind_group_desc;
|
|
666
|
+
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
|
|
667
|
+
bind_group_desc.entryCount = entries.size();
|
|
668
|
+
bind_group_desc.entries = entries.data();
|
|
669
|
+
bind_group_desc.label = pipelines[i].name.c_str();
|
|
670
|
+
bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
|
|
671
|
+
|
|
672
|
+
params_bufs_list.push_back(params_bufs);
|
|
385
673
|
}
|
|
386
|
-
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
|
|
387
674
|
|
|
388
675
|
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
|
389
|
-
|
|
676
|
+
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
|
677
|
+
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
|
678
|
+
}
|
|
679
|
+
|
|
680
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
681
|
+
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
|
|
682
|
+
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
|
683
|
+
ts_bufs.host_buf.Unmap();
|
|
684
|
+
}
|
|
685
|
+
|
|
686
|
+
wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
|
|
687
|
+
.beginningOfPassWriteIndex = 0,
|
|
688
|
+
.endOfPassWriteIndex = 1 };
|
|
689
|
+
wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
|
|
690
|
+
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
|
|
691
|
+
#else
|
|
390
692
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
693
|
+
#endif
|
|
694
|
+
for (size_t i = 0; i < pipelines.size(); i++) {
|
|
695
|
+
pass.SetPipeline(pipelines[i].pipeline);
|
|
696
|
+
pass.SetBindGroup(0, bind_groups[i]);
|
|
697
|
+
pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
|
|
698
|
+
}
|
|
394
699
|
pass.End();
|
|
700
|
+
|
|
701
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
702
|
+
encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
|
|
703
|
+
encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
|
|
704
|
+
#endif
|
|
705
|
+
|
|
395
706
|
wgpu::CommandBuffer commands = encoder.Finish();
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
707
|
+
webgpu_command result = {};
|
|
708
|
+
result.commands = commands;
|
|
709
|
+
result.params_bufs = params_bufs_list;
|
|
710
|
+
result.num_kernels = pipelines.size();
|
|
711
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
712
|
+
result.timestamp_query_bufs = ts_bufs;
|
|
713
|
+
// TODO: handle multiple pipeline names
|
|
714
|
+
result.pipeline_name = pipelines.front().name;
|
|
715
|
+
#endif
|
|
716
|
+
return result;
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
|
|
720
|
+
webgpu_buf_pool & param_buf_pool,
|
|
721
|
+
webgpu_pipeline & pipeline,
|
|
722
|
+
std::vector<uint32_t> params,
|
|
723
|
+
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
|
724
|
+
uint32_t wg_x,
|
|
725
|
+
uint32_t wg_y = 1) {
|
|
726
|
+
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
|
|
727
|
+
{
|
|
728
|
+
pipeline
|
|
729
|
+
},
|
|
730
|
+
{ std::move(params) }, { std::move(bind_group_entries) },
|
|
731
|
+
{ { wg_x, wg_y } });
|
|
418
732
|
}
|
|
419
733
|
|
|
420
|
-
static void ggml_backend_webgpu_buffer_memset(
|
|
421
|
-
wgpu::Buffer &
|
|
422
|
-
uint32_t
|
|
423
|
-
size_t
|
|
424
|
-
size_t
|
|
734
|
+
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
|
735
|
+
wgpu::Buffer & buf,
|
|
736
|
+
uint32_t value,
|
|
737
|
+
size_t offset,
|
|
738
|
+
size_t size) {
|
|
425
739
|
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
|
|
426
740
|
std::vector<wgpu::BindGroupEntry> entries = {
|
|
427
741
|
{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
|
|
428
742
|
};
|
|
429
|
-
size_t bytes_per_wg =
|
|
430
|
-
uint32_t wg_x = (
|
|
431
|
-
|
|
743
|
+
size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
|
|
744
|
+
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
|
|
745
|
+
|
|
746
|
+
webgpu_command command =
|
|
747
|
+
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
|
748
|
+
std::vector<webgpu_command> commands = { command };
|
|
749
|
+
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
|
|
750
|
+
ggml_backend_webgpu_wait(ctx, sub);
|
|
432
751
|
}
|
|
433
752
|
|
|
434
753
|
/** End WebGPU Actions */
|
|
@@ -444,8 +763,48 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
|
|
444
763
|
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
|
445
764
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
|
446
765
|
|
|
447
|
-
|
|
448
|
-
|
|
766
|
+
#ifdef GGML_WEBGPU_CPU_PROFILE
|
|
767
|
+
std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
|
|
768
|
+
double total_cpu = 0.0;
|
|
769
|
+
for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
|
|
770
|
+
total_cpu += kv.second;
|
|
771
|
+
}
|
|
772
|
+
std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
|
|
773
|
+
std::cout << "ggml_webgpu: cpu breakdown:\n";
|
|
774
|
+
for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
|
|
775
|
+
double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
|
|
776
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
|
777
|
+
}
|
|
778
|
+
if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
|
|
779
|
+
std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
|
|
780
|
+
}
|
|
781
|
+
for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
|
|
782
|
+
double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
|
|
783
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
|
784
|
+
}
|
|
785
|
+
#endif
|
|
786
|
+
|
|
787
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
788
|
+
std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
|
|
789
|
+
double total_gpu = 0.0;
|
|
790
|
+
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
|
|
791
|
+
total_gpu += kv.second;
|
|
792
|
+
}
|
|
793
|
+
std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
|
|
794
|
+
std::cout << "\nggml_webgpu: gpu breakdown:\n";
|
|
795
|
+
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
|
|
796
|
+
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
|
|
797
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
|
|
798
|
+
<< pct << "%)\n";
|
|
799
|
+
}
|
|
800
|
+
#endif
|
|
801
|
+
|
|
802
|
+
#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
|
|
803
|
+
std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
|
|
804
|
+
#endif
|
|
805
|
+
|
|
806
|
+
delete ctx;
|
|
807
|
+
delete backend;
|
|
449
808
|
}
|
|
450
809
|
|
|
451
810
|
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
@@ -457,19 +816,18 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
|
|
|
457
816
|
return ctx->buffer;
|
|
458
817
|
}
|
|
459
818
|
|
|
460
|
-
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
|
|
819
|
+
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
|
|
461
820
|
size_t offset = ggml_webgpu_tensor_offset(t);
|
|
462
|
-
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
821
|
+
return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
|
463
822
|
}
|
|
464
823
|
|
|
465
|
-
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
|
|
824
|
+
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
|
|
466
825
|
size_t offset = ggml_webgpu_tensor_offset(t);
|
|
467
|
-
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
|
|
826
|
+
return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
|
|
468
827
|
}
|
|
469
828
|
|
|
470
829
|
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
|
|
471
|
-
return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t)
|
|
472
|
-
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
|
|
830
|
+
return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
|
473
831
|
}
|
|
474
832
|
|
|
475
833
|
// Used to determine if two tensors are the same for in-place operations
|
|
@@ -478,7 +836,31 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
|
|
478
836
|
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
|
479
837
|
}
|
|
480
838
|
|
|
481
|
-
|
|
839
|
+
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
|
|
840
|
+
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
|
|
841
|
+
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
|
842
|
+
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
|
|
843
|
+
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
|
|
844
|
+
}
|
|
845
|
+
|
|
846
|
+
struct binary_overlap_flags {
|
|
847
|
+
bool inplace; // src0 == dst
|
|
848
|
+
bool overlap; // src1 == dst
|
|
849
|
+
bool src_overlap;
|
|
850
|
+
};
|
|
851
|
+
|
|
852
|
+
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
|
|
853
|
+
ggml_tensor * src1,
|
|
854
|
+
ggml_tensor * dst) {
|
|
855
|
+
binary_overlap_flags flags = {};
|
|
856
|
+
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
|
857
|
+
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
|
|
858
|
+
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
|
|
859
|
+
|
|
860
|
+
return flags;
|
|
861
|
+
}
|
|
862
|
+
|
|
863
|
+
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
|
482
864
|
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
|
483
865
|
|
|
484
866
|
std::vector<uint32_t> params = {
|
|
@@ -489,8 +871,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
|
|
489
871
|
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
|
490
872
|
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
|
491
873
|
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
|
492
|
-
// Logical
|
|
493
|
-
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t)
|
|
874
|
+
// Logical shapes
|
|
875
|
+
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
|
|
876
|
+
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
|
|
494
877
|
};
|
|
495
878
|
|
|
496
879
|
std::vector<wgpu::BindGroupEntry> entries = {
|
|
@@ -504,36 +887,49 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
|
|
|
504
887
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
505
888
|
};
|
|
506
889
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
890
|
+
uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
|
|
891
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
|
|
892
|
+
params, entries, wg_x);
|
|
510
893
|
}
|
|
511
894
|
|
|
512
|
-
static
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
}
|
|
895
|
+
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
|
896
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
897
|
+
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
|
898
|
+
};
|
|
517
899
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
900
|
+
webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
|
|
901
|
+
|
|
902
|
+
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
|
903
|
+
|
|
904
|
+
const uint32_t ne = (uint32_t) ggml_nelements(dst);
|
|
522
905
|
|
|
523
906
|
std::vector<uint32_t> params = {
|
|
907
|
+
ne,
|
|
524
908
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
|
525
|
-
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
|
526
909
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
527
|
-
//
|
|
528
|
-
(uint32_t) (src->nb[
|
|
529
|
-
(uint32_t) (src->nb[
|
|
530
|
-
(uint32_t) (
|
|
531
|
-
(uint32_t) (
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
(uint32_t) src->ne[
|
|
535
|
-
|
|
536
|
-
(uint32_t)
|
|
910
|
+
// Strides (in elements)
|
|
911
|
+
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
|
|
912
|
+
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
|
913
|
+
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
|
914
|
+
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
|
915
|
+
// Shapes
|
|
916
|
+
(uint32_t) src->ne[0],
|
|
917
|
+
(uint32_t) src->ne[1],
|
|
918
|
+
(uint32_t) src->ne[2],
|
|
919
|
+
(uint32_t) src->ne[3],
|
|
920
|
+
(uint32_t) dst->ne[0],
|
|
921
|
+
(uint32_t) dst->ne[1],
|
|
922
|
+
(uint32_t) dst->ne[2],
|
|
923
|
+
(uint32_t) dst->ne[3],
|
|
924
|
+
// Pad sizes
|
|
925
|
+
(uint32_t) ggml_get_op_params_i32(dst, 0),
|
|
926
|
+
(uint32_t) ggml_get_op_params_i32(dst, 1),
|
|
927
|
+
(uint32_t) ggml_get_op_params_i32(dst, 2),
|
|
928
|
+
(uint32_t) ggml_get_op_params_i32(dst, 3),
|
|
929
|
+
(uint32_t) ggml_get_op_params_i32(dst, 4),
|
|
930
|
+
(uint32_t) ggml_get_op_params_i32(dst, 5),
|
|
931
|
+
(uint32_t) ggml_get_op_params_i32(dst, 6),
|
|
932
|
+
(uint32_t) ggml_get_op_params_i32(dst, 7),
|
|
537
933
|
};
|
|
538
934
|
|
|
539
935
|
std::vector<wgpu::BindGroupEntry> entries = {
|
|
@@ -542,26 +938,36 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|
|
542
938
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
|
543
939
|
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
|
544
940
|
{ .binding = 1,
|
|
545
|
-
.buffer = ggml_webgpu_tensor_buf(idx),
|
|
546
|
-
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
|
547
|
-
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
|
548
|
-
{ .binding = 2,
|
|
549
941
|
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
550
942
|
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
551
|
-
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
552
|
-
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
|
|
943
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
553
944
|
};
|
|
554
945
|
|
|
555
|
-
|
|
556
|
-
|
|
946
|
+
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
|
947
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
|
948
|
+
}
|
|
557
949
|
|
|
558
|
-
|
|
559
|
-
|
|
950
|
+
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|
951
|
+
ggml_tensor * src,
|
|
952
|
+
ggml_tensor * idx,
|
|
953
|
+
ggml_tensor * dst) {
|
|
954
|
+
// For set rows specifically, we need to check if src and idx are empty
|
|
955
|
+
// tensors.
|
|
956
|
+
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
|
|
957
|
+
return std::nullopt;
|
|
958
|
+
}
|
|
560
959
|
|
|
561
|
-
|
|
562
|
-
|
|
960
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
961
|
+
.src0 = src,
|
|
962
|
+
.src1 = idx,
|
|
963
|
+
.dst = dst,
|
|
964
|
+
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
|
965
|
+
};
|
|
966
|
+
|
|
967
|
+
webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
|
|
968
|
+
|
|
969
|
+
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
|
|
563
970
|
|
|
564
|
-
static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
|
565
971
|
std::vector<uint32_t> params = {
|
|
566
972
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
|
567
973
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
|
@@ -572,8 +978,8 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|
|
572
978
|
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
|
573
979
|
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
|
574
980
|
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
|
575
|
-
// Shape of
|
|
576
|
-
(uint32_t)
|
|
981
|
+
// Shape of src
|
|
982
|
+
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
|
|
577
983
|
// Shape of idx
|
|
578
984
|
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
|
|
579
985
|
};
|
|
@@ -593,43 +999,177 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|
|
593
999
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
594
1000
|
};
|
|
595
1001
|
|
|
596
|
-
|
|
597
|
-
|
|
1002
|
+
if (decisions->i64_idx) {
|
|
1003
|
+
entries.push_back({ .binding = 3,
|
|
1004
|
+
.buffer = ctx->set_rows_dev_error_buf,
|
|
1005
|
+
.offset = 0,
|
|
1006
|
+
.size = ctx->set_rows_dev_error_buf.GetSize() });
|
|
1007
|
+
}
|
|
598
1008
|
|
|
599
|
-
|
|
600
|
-
if (
|
|
601
|
-
|
|
1009
|
+
uint32_t threads;
|
|
1010
|
+
if (decisions->vec4) {
|
|
1011
|
+
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
|
|
1012
|
+
} else {
|
|
1013
|
+
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
|
602
1014
|
}
|
|
603
|
-
|
|
1015
|
+
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
|
|
1016
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
|
|
604
1017
|
}
|
|
605
1018
|
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
(uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
|
|
1019
|
+
// Workgroup size is a common constant
|
|
1020
|
+
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
|
|
1021
|
+
std::vector<wgpu::ConstantEntry> constants(1);
|
|
1022
|
+
constants[0].key = "wg_size";
|
|
1023
|
+
constants[0].value = wg_size;
|
|
1024
|
+
return constants;
|
|
1025
|
+
}
|
|
1026
|
+
|
|
1027
|
+
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
|
1028
|
+
ggml_tensor * src,
|
|
1029
|
+
ggml_tensor * idx,
|
|
1030
|
+
ggml_tensor * dst) {
|
|
1031
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
1032
|
+
.src0 = src,
|
|
1033
|
+
.src1 = nullptr,
|
|
1034
|
+
.dst = dst,
|
|
1035
|
+
.max_wg_size = WEBGPU_MAX_WG_SIZE,
|
|
624
1036
|
};
|
|
625
1037
|
|
|
1038
|
+
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
|
|
1039
|
+
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
|
1040
|
+
|
|
1041
|
+
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
|
1042
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
|
1043
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
1044
|
+
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
|
1045
|
+
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
|
|
1046
|
+
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
|
1047
|
+
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
|
|
1048
|
+
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
|
|
1049
|
+
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
|
|
1050
|
+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
|
1051
|
+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
|
1052
|
+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
|
1053
|
+
(uint32_t) dst->ne[0],
|
|
1054
|
+
(uint32_t) dst->ne[1],
|
|
1055
|
+
(uint32_t) dst->ne[2],
|
|
1056
|
+
(uint32_t) dst->ne[3],
|
|
1057
|
+
(uint32_t) (idx->ne[1]),
|
|
1058
|
+
(uint32_t) (idx->ne[2]) };
|
|
1059
|
+
|
|
626
1060
|
std::vector<wgpu::BindGroupEntry> entries = {
|
|
627
1061
|
{ .binding = 0,
|
|
628
|
-
.buffer = ggml_webgpu_tensor_buf(
|
|
629
|
-
.offset = ggml_webgpu_tensor_align_offset(ctx,
|
|
630
|
-
.size = ggml_webgpu_tensor_binding_size(ctx,
|
|
1062
|
+
.buffer = ggml_webgpu_tensor_buf(src),
|
|
1063
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
|
1064
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
|
631
1065
|
{ .binding = 1,
|
|
632
|
-
.buffer = ggml_webgpu_tensor_buf(
|
|
1066
|
+
.buffer = ggml_webgpu_tensor_buf(idx),
|
|
1067
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, idx),
|
|
1068
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, idx) },
|
|
1069
|
+
{ .binding = 2,
|
|
1070
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
1071
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
1072
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
|
1073
|
+
};
|
|
1074
|
+
|
|
1075
|
+
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
|
|
1076
|
+
|
|
1077
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
|
1078
|
+
}
|
|
1079
|
+
|
|
1080
|
+
static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|
1081
|
+
ggml_tensor * src0,
|
|
1082
|
+
ggml_tensor * src1,
|
|
1083
|
+
ggml_tensor * dst) {
|
|
1084
|
+
// Determine if this is a mat-vec operation
|
|
1085
|
+
bool is_vec = (dst->ne[1] == 1);
|
|
1086
|
+
|
|
1087
|
+
// Determine if we should use fast path
|
|
1088
|
+
bool use_fast = false;
|
|
1089
|
+
switch (src1->type) {
|
|
1090
|
+
case GGML_TYPE_F16:
|
|
1091
|
+
use_fast = (src0->type == GGML_TYPE_F16);
|
|
1092
|
+
break;
|
|
1093
|
+
case GGML_TYPE_F32:
|
|
1094
|
+
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
|
1095
|
+
switch (src0->type) {
|
|
1096
|
+
case GGML_TYPE_F32:
|
|
1097
|
+
case GGML_TYPE_F16:
|
|
1098
|
+
case GGML_TYPE_Q4_0:
|
|
1099
|
+
case GGML_TYPE_Q4_1:
|
|
1100
|
+
case GGML_TYPE_Q5_0:
|
|
1101
|
+
case GGML_TYPE_Q5_1:
|
|
1102
|
+
case GGML_TYPE_Q8_0:
|
|
1103
|
+
case GGML_TYPE_Q8_1:
|
|
1104
|
+
case GGML_TYPE_Q6_K:
|
|
1105
|
+
use_fast = true;
|
|
1106
|
+
break;
|
|
1107
|
+
case GGML_TYPE_Q2_K:
|
|
1108
|
+
case GGML_TYPE_Q3_K:
|
|
1109
|
+
case GGML_TYPE_Q4_K:
|
|
1110
|
+
case GGML_TYPE_Q5_K:
|
|
1111
|
+
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
|
|
1112
|
+
use_fast = !is_vec;
|
|
1113
|
+
break;
|
|
1114
|
+
default:
|
|
1115
|
+
break;
|
|
1116
|
+
}
|
|
1117
|
+
break;
|
|
1118
|
+
default:
|
|
1119
|
+
break;
|
|
1120
|
+
}
|
|
1121
|
+
|
|
1122
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
1123
|
+
.src0 = src0,
|
|
1124
|
+
.src1 = src1,
|
|
1125
|
+
.dst = dst,
|
|
1126
|
+
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
|
1127
|
+
.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
|
|
1128
|
+
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
|
|
1129
|
+
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
|
|
1130
|
+
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
|
|
1131
|
+
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
|
|
1132
|
+
};
|
|
1133
|
+
|
|
1134
|
+
// Get or create pipeline
|
|
1135
|
+
webgpu_pipeline pipeline;
|
|
1136
|
+
|
|
1137
|
+
if (use_fast && is_vec) {
|
|
1138
|
+
pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
|
|
1139
|
+
} else if (use_fast) {
|
|
1140
|
+
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
|
|
1141
|
+
} else {
|
|
1142
|
+
pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
|
|
1143
|
+
}
|
|
1144
|
+
|
|
1145
|
+
// Build params
|
|
1146
|
+
std::vector<uint32_t> params = {
|
|
1147
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
|
1148
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
|
1149
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
1150
|
+
(uint32_t) dst->ne[0],
|
|
1151
|
+
(uint32_t) dst->ne[1],
|
|
1152
|
+
(uint32_t) src0->ne[0],
|
|
1153
|
+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
|
1154
|
+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
|
1155
|
+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
|
1156
|
+
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
|
1157
|
+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
|
1158
|
+
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
|
1159
|
+
(uint32_t) src0->ne[2],
|
|
1160
|
+
(uint32_t) src0->ne[3],
|
|
1161
|
+
(uint32_t) (src1->ne[2] / src0->ne[2]),
|
|
1162
|
+
(uint32_t) (src1->ne[3] / src0->ne[3])
|
|
1163
|
+
};
|
|
1164
|
+
|
|
1165
|
+
// Build bind group entries
|
|
1166
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
1167
|
+
{ .binding = 0,
|
|
1168
|
+
.buffer = ggml_webgpu_tensor_buf(src0),
|
|
1169
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
|
1170
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
|
1171
|
+
{ .binding = 1,
|
|
1172
|
+
.buffer = ggml_webgpu_tensor_buf(src1),
|
|
633
1173
|
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
|
634
1174
|
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
|
635
1175
|
{ .binding = 2,
|
|
@@ -638,23 +1178,281 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
|
|
|
638
1178
|
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
|
639
1179
|
};
|
|
640
1180
|
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
1181
|
+
// Calculate workgroup dimensions
|
|
1182
|
+
uint32_t wg_x = 1;
|
|
1183
|
+
uint32_t wg_y = 1;
|
|
1184
|
+
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
|
1185
|
+
|
|
1186
|
+
if (use_fast && is_vec) {
|
|
1187
|
+
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
|
1188
|
+
|
|
1189
|
+
uint32_t batches = dst->ne[2] * dst->ne[3];
|
|
1190
|
+
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
|
1191
|
+
uint32_t total_wg = output_groups * batches;
|
|
1192
|
+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
|
1193
|
+
} else if (use_fast) {
|
|
1194
|
+
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
|
1195
|
+
|
|
1196
|
+
// Fast-path tiled/subgroup calculations
|
|
1197
|
+
uint32_t wg_m;
|
|
1198
|
+
uint32_t wg_n;
|
|
1199
|
+
if (decisions->use_subgroup_matrix) {
|
|
1200
|
+
uint32_t wg_m_sg_tile =
|
|
1201
|
+
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
|
|
1202
|
+
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
|
|
1203
|
+
uint32_t wg_n_sg_tile =
|
|
1204
|
+
decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
|
|
1205
|
+
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
|
|
1206
|
+
} else {
|
|
1207
|
+
uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
|
|
1208
|
+
uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
|
|
1209
|
+
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
|
|
1210
|
+
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
|
|
1211
|
+
}
|
|
1212
|
+
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
|
|
1213
|
+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
|
1214
|
+
|
|
1215
|
+
} else { // legacy
|
|
1216
|
+
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
|
1217
|
+
uint32_t wg_size = decisions->wg_size;
|
|
1218
|
+
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
|
1219
|
+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
|
645
1223
|
}
|
|
646
1224
|
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1225
|
+
#ifndef __EMSCRIPTEN__
|
|
1226
|
+
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
|
|
1227
|
+
ggml_tensor * Q,
|
|
1228
|
+
ggml_tensor * K,
|
|
1229
|
+
ggml_tensor * V,
|
|
1230
|
+
ggml_tensor * mask,
|
|
1231
|
+
ggml_tensor * sinks,
|
|
1232
|
+
ggml_tensor * dst) {
|
|
1233
|
+
float scale = *(float *) dst->op_params;
|
|
1234
|
+
float max_bias;
|
|
1235
|
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
1236
|
+
float logit_softcap;
|
|
1237
|
+
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
|
1238
|
+
if (logit_softcap != 0.0f) {
|
|
1239
|
+
scale /= logit_softcap;
|
|
1240
|
+
}
|
|
1241
|
+
float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
|
|
1242
|
+
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
|
1243
|
+
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
1244
|
+
|
|
1245
|
+
const int has_mask = (mask != nullptr);
|
|
1246
|
+
const int has_sinks = (sinks != nullptr);
|
|
1247
|
+
|
|
653
1248
|
std::vector<uint32_t> params = {
|
|
654
|
-
(uint32_t)
|
|
1249
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
|
1250
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
|
|
1251
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
|
|
1252
|
+
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
|
1253
|
+
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
|
1254
|
+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
1255
|
+
(uint32_t) Q->ne[2], // number of heads
|
|
1256
|
+
(uint32_t) Q->ne[1], // sequence length (Q)
|
|
1257
|
+
(uint32_t) K->ne[1], // sequence length (K/V)
|
|
1258
|
+
(uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
|
|
1259
|
+
(uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
|
|
1260
|
+
(uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
|
|
1261
|
+
(uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
|
|
1262
|
+
(uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
|
|
1263
|
+
(uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
|
|
1264
|
+
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
|
|
1265
|
+
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
|
|
1266
|
+
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
|
|
1267
|
+
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
|
|
1268
|
+
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
|
|
1269
|
+
*(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
|
|
1270
|
+
*(uint32_t *) &max_bias,
|
|
1271
|
+
*(uint32_t *) &logit_softcap,
|
|
1272
|
+
*(uint32_t *) &n_head_log2,
|
|
1273
|
+
*(uint32_t *) &m0,
|
|
1274
|
+
*(uint32_t *) &m1
|
|
1275
|
+
|
|
1276
|
+
};
|
|
1277
|
+
std::vector<wgpu::BindGroupEntry> entries = {
|
|
1278
|
+
{ .binding = 0,
|
|
1279
|
+
.buffer = ggml_webgpu_tensor_buf(Q),
|
|
1280
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
|
|
1281
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
|
|
1282
|
+
{ .binding = 1,
|
|
1283
|
+
.buffer = ggml_webgpu_tensor_buf(K),
|
|
1284
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
|
|
1285
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
|
|
1286
|
+
{ .binding = 2,
|
|
1287
|
+
.buffer = ggml_webgpu_tensor_buf(V),
|
|
1288
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
|
|
1289
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, V) }
|
|
1290
|
+
};
|
|
1291
|
+
uint32_t binding_index = 3;
|
|
1292
|
+
if (has_mask) {
|
|
1293
|
+
entries.push_back({ .binding = binding_index++,
|
|
1294
|
+
.buffer = ggml_webgpu_tensor_buf(mask),
|
|
1295
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
|
|
1296
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
|
|
1297
|
+
}
|
|
1298
|
+
if (has_sinks) {
|
|
1299
|
+
entries.push_back({ .binding = binding_index++,
|
|
1300
|
+
.buffer = ggml_webgpu_tensor_buf(sinks),
|
|
1301
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
|
|
1302
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
|
|
1303
|
+
}
|
|
1304
|
+
entries.push_back({ .binding = binding_index++,
|
|
1305
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
1306
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
1307
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
|
1308
|
+
|
|
1309
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
1310
|
+
.src0 = Q,
|
|
1311
|
+
.src1 = K,
|
|
1312
|
+
.src2 = V,
|
|
1313
|
+
.src3 = mask,
|
|
1314
|
+
.src4 = sinks,
|
|
1315
|
+
.dst = dst,
|
|
1316
|
+
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
|
1317
|
+
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
|
1318
|
+
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
|
|
1319
|
+
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
|
|
1320
|
+
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
|
|
1321
|
+
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
|
|
1322
|
+
};
|
|
1323
|
+
|
|
1324
|
+
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
|
|
1325
|
+
|
|
1326
|
+
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
|
|
1327
|
+
|
|
1328
|
+
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
|
1329
|
+
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
|
1330
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
|
1331
|
+
}
|
|
1332
|
+
#endif
|
|
1333
|
+
|
|
1334
|
+
// Encodes a single-source elementwise op as one compute dispatch and returns the command.
//
// Common params (u32 words): ne(dst); src and dst misalignment (in elements); src strides
// nb[0..3] (in elements); src ne[0..2]. Op-specific float parameters follow, sent as raw
// IEEE-754 bit patterns:
//   GGML_OP_UNARY with GGML_UNARY_OP_XIELU : alpha_n, alpha_p, beta, eps (op_params[1..4])
//   GGML_OP_CLAMP                          : min, max                    (op_params[0..1])
//   GGML_OP_FILL                           : fill value                  (op_params[0])
//
// Binding 0 is the source. For FILL it is dst itself, since FILL only writes dst.
// Binding 1 is dst and is only bound for out-of-place ops.
// The dispatch size is ceil(ne / wg_size) from the selected pipeline's decisions.
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    // Copy a float's bit pattern into a u32 for the params buffer. Reading a float object
    // through a uint32_t pointer (reinterpret_cast) violates strict aliasing and is UB;
    // memcpy is the portable way to type-pun.
    auto f32_bits = [](float v) -> uint32_t {
        uint32_t bits;
        memcpy(&bits, &v, sizeof(bits));
        return bits;
    };

    bool is_unary = dst->op == GGML_OP_UNARY;
    // FILL never binds src (see effective_src below), so it is always handled as in-place.
    bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0        = src,
        .src1        = nullptr,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .inplace     = inplace,
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);

    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    uint32_t ne = (uint32_t) ggml_nelements(dst);

    std::vector<uint32_t> params = { ne,
                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
                                     (uint32_t) src->ne[0],
                                     (uint32_t) src->ne[1],
                                     (uint32_t) src->ne[2] };

    ggml_tensor * effective_src = src;
    if (is_unary) {
        ggml_unary_op unary_op = ggml_get_unary_op(dst);
        switch (unary_op) {
            case GGML_UNARY_OP_XIELU:
                {
                    params.push_back(f32_bits(ggml_get_op_params_f32(dst, 1)));  // alpha_n
                    params.push_back(f32_bits(ggml_get_op_params_f32(dst, 2)));  // alpha_p
                    params.push_back(f32_bits(ggml_get_op_params_f32(dst, 3)));  // beta
                    params.push_back(f32_bits(ggml_get_op_params_f32(dst, 4)));  // eps
                    break;
                }
            default:
                break;
        }
    } else if (dst->op == GGML_OP_CLAMP) {
        params.push_back(f32_bits(ggml_get_op_params_f32(dst, 0)));  // clamp_min
        params.push_back(f32_bits(ggml_get_op_params_f32(dst, 1)));  // clamp_max
    } else if (dst->op == GGML_OP_FILL) {
        params.push_back(f32_bits(ggml_get_op_params_f32(dst, 0)));  // fill value
        effective_src = dst;  // fill simply fills dst
    }

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(effective_src),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, effective_src),
         .size    = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
    };
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
|
|
1411
|
+
|
|
1412
|
+
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
|
1413
|
+
ggml_tensor * src0,
|
|
1414
|
+
ggml_tensor * src1,
|
|
1415
|
+
ggml_tensor * dst) {
|
|
1416
|
+
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
|
|
1417
|
+
|
|
1418
|
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
|
1419
|
+
.src0 = src0,
|
|
1420
|
+
.src1 = src1,
|
|
1421
|
+
.dst = dst,
|
|
1422
|
+
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
|
1423
|
+
.inplace = flags.inplace,
|
|
1424
|
+
.overlap = flags.overlap,
|
|
1425
|
+
.src_overlap = flags.src_overlap,
|
|
1426
|
+
};
|
|
1427
|
+
|
|
1428
|
+
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
|
|
1429
|
+
|
|
1430
|
+
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
|
1431
|
+
|
|
1432
|
+
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
|
1433
|
+
|
|
1434
|
+
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
|
|
1435
|
+
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
|
|
1436
|
+
|
|
1437
|
+
uint32_t offset_merged_src0 = 0;
|
|
1438
|
+
uint32_t offset_merged_src1 = 0;
|
|
1439
|
+
if (flags.src_overlap) {
|
|
1440
|
+
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
|
1441
|
+
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
|
1442
|
+
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
|
|
1443
|
+
}
|
|
1444
|
+
|
|
1445
|
+
std::vector<uint32_t> params = {
|
|
1446
|
+
ne,
|
|
655
1447
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
|
656
1448
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
|
657
1449
|
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
|
1450
|
+
offset_merged_src0,
|
|
1451
|
+
offset_merged_src1,
|
|
1452
|
+
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
|
1453
|
+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
|
1454
|
+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
|
1455
|
+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
|
658
1456
|
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
|
659
1457
|
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
|
660
1458
|
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
|
@@ -668,87 +1466,709 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
|
|
668
1466
|
(uint32_t) src1->ne[3],
|
|
669
1467
|
};
|
|
670
1468
|
|
|
671
|
-
std::vector<wgpu::BindGroupEntry> entries
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
1469
|
+
std::vector<wgpu::BindGroupEntry> entries;
|
|
1470
|
+
|
|
1471
|
+
if (flags.src_overlap) {
|
|
1472
|
+
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
|
|
1473
|
+
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
|
|
1474
|
+
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
|
|
1475
|
+
entries.push_back({
|
|
1476
|
+
.binding = 0,
|
|
1477
|
+
.buffer = ggml_webgpu_tensor_buf(src0),
|
|
1478
|
+
.offset = merged_offset,
|
|
1479
|
+
.size = merged_end - merged_offset,
|
|
1480
|
+
});
|
|
1481
|
+
entries.push_back({
|
|
1482
|
+
.binding = 1,
|
|
1483
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
1484
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
1485
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
|
1486
|
+
});
|
|
1487
|
+
} else {
|
|
1488
|
+
entries.push_back({
|
|
1489
|
+
.binding = 0,
|
|
1490
|
+
.buffer = ggml_webgpu_tensor_buf(src0),
|
|
1491
|
+
.offset = src0_webgpu_tensor_align_offset,
|
|
1492
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
|
1493
|
+
});
|
|
1494
|
+
entries.push_back({
|
|
1495
|
+
.binding = 1,
|
|
1496
|
+
.buffer = ggml_webgpu_tensor_buf(src1),
|
|
1497
|
+
.offset = src1_webgpu_tensor_align_offset,
|
|
1498
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
|
1499
|
+
});
|
|
1500
|
+
if (!flags.inplace && !flags.overlap) {
|
|
1501
|
+
entries.push_back({
|
|
1502
|
+
.binding = 2,
|
|
1503
|
+
.buffer = ggml_webgpu_tensor_buf(dst),
|
|
1504
|
+
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
|
1505
|
+
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
|
|
1506
|
+
});
|
|
1507
|
+
}
|
|
1508
|
+
}
|
|
1509
|
+
|
|
1510
|
+
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
|
1511
|
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
|
1512
|
+
}
|
|
1513
|
+
|
|
1514
|
+
// Encodes a concat of src0 and src1 into dst along dimension op_params[0].
//
// Params (u32 words), in order:
//   - ne(dst)
//   - src0/src1/dst misalignment (in elements)
//   - src0 nb[0..3] and src1 nb[0..3] (in elements)
//   - dst ne[0..3]
//   - the concat dim
//   - src0->ne[dim] (presumably where src1's data starts along dim in dst; shader not visible, TODO confirm)
// Bindings: 0 = src0, 1 = src1, 2 = dst.
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
                                         ggml_tensor *    src0,
                                         ggml_tensor *    src1,
                                         ggml_tensor *    dst) {
    // Converts a byte quantity into an element count of t's type.
    auto in_elems = [](uint64_t bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };
    // Binding entry covering the whole tensor t at the given slot.
    auto entry_for = [&ctx](uint32_t slot, ggml_tensor * t) -> wgpu::BindGroupEntry {
        return { .binding = slot,
                 .buffer  = ggml_webgpu_tensor_buf(t),
                 .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                 .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };

    const uint32_t total = (uint32_t) ggml_nelements(dst);
    const uint32_t axis  = (uint32_t) dst->op_params[0];

    std::vector<uint32_t> params = {
        total,
        in_elems(ggml_webgpu_tensor_misalignment(ctx, src0), src0),
        in_elems(ggml_webgpu_tensor_misalignment(ctx, src1), src1),
        in_elems(ggml_webgpu_tensor_misalignment(ctx, dst), dst),
    };
    for (const ggml_tensor * t : { src0, src1 }) {
        for (int d = 0; d < 4; d++) {
            params.push_back(in_elems(t->nb[d], t));
        }
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) dst->ne[d]);
    }
    params.push_back(axis);
    params.push_back((uint32_t) src0->ne[axis]);

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back(entry_for(0, src0));
    entries.push_back(entry_for(1, src1));
    entries.push_back(entry_for(2, dst));

    ggml_webgpu_shader_lib_context lib_ctx = {
        .src0        = src0,
        .src1        = src1,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    webgpu_pipeline concat_pipeline = ctx->shader_lib->get_concat_pipeline(lib_ctx);
    const auto *    choice = static_cast<ggml_webgpu_generic_shader_decisions *>(concat_pipeline.context.get());

    const uint32_t groups = CEIL_DIV(total, choice->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, concat_pipeline, params, entries, groups);
}
|
|
1569
|
+
|
|
1570
|
+
// Encodes a repeat of src0 into dst as one compute dispatch of ceil(ne(dst) / wg_size) workgroups.
//
// Params (u32 words): ne(dst); src0 and dst misalignment (in elements); src0 nb[0..3]
// (in elements); src0 ne[0..3]; dst ne[0..2].
// Bindings: 0 = src0, 1 = dst.
static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
    // Converts a byte quantity into an element count of t's type.
    auto in_elems = [](uint64_t bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };

    const uint32_t total = (uint32_t) ggml_nelements(dst);

    std::vector<uint32_t> params;
    params.reserve(14);
    params.push_back(total);
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, src0), src0));
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, dst), dst));
    for (int d = 0; d < 4; d++) {
        params.push_back(in_elems(src0->nb[d], src0));
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) (src0->ne[d]));
    }
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) (dst->ne[d]));
    }

    wgpu::BindGroupEntry src_entry = { .binding = 0,
                                       .buffer  = ggml_webgpu_tensor_buf(src0),
                                       .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
                                       .size    = ggml_webgpu_tensor_binding_size(ctx, src0) };
    wgpu::BindGroupEntry dst_entry = { .binding = 1,
                                       .buffer  = ggml_webgpu_tensor_buf(dst),
                                       .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                                       .size    = ggml_webgpu_tensor_binding_size(ctx, dst) };
    std::vector<wgpu::BindGroupEntry> entries = { src_entry, dst_entry };

    ggml_webgpu_shader_lib_context lib_ctx = {
        .src0        = src0,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    webgpu_pipeline repeat_pipeline = ctx->shader_lib->get_repeat_pipeline(lib_ctx);
    const auto *    choice = static_cast<ggml_webgpu_generic_shader_decisions *>(repeat_pipeline.context.get());

    const uint32_t groups = CEIL_DIV(total, choice->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, repeat_pipeline, params, entries, groups);
}
|
|
1611
|
+
|
|
1612
|
+
// Encodes an RMS-norm of src into dst, dispatching ggml_nrows(src) workgroups.
//
// Params (u32 words): src/dst misalignment (in elements); src nb[1..3] and dst nb[1..3]
// (in elements); src ne[0..3]; epsilon (op_params[0], raw f32 bits, read as f32 by the shader).
// Binding 0 = src; binding 1 = dst, only when the op is not in-place. The pipeline is chosen
// from ctx->rms_norm_pipelines by the in-place flag.
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const int same_tensor = ggml_webgpu_tensor_equal(src, dst);

    // Converts a byte quantity into an element count of t's type.
    auto in_elems = [](uint64_t bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };

    std::vector<uint32_t> params;
    params.reserve(13);
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, src), src));
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, dst), dst));
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(src->nb[d], src));
    }
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(dst->nb[d], dst));
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) src->ne[d]);
    }
    params.push_back((uint32_t) dst->op_params[0]);  // epsilon bits

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back({ .binding = 0,
                        .buffer  = ggml_webgpu_tensor_buf(src),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src) });
    if (same_tensor == 0) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    webgpu_pipeline chosen = ctx->rms_norm_pipelines[same_tensor];
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, chosen, params, entries, ggml_nrows(src));
}
|
|
1647
|
+
|
|
1648
|
+
// Encodes a RoPE op: src0 is the input, src1 a second input (presumably positions — TODO
// confirm against shader), src2 optional frequency factors.
//
// op_params (int32 words):
//   [1] n_dims        [2] mode          [4] n_ctx_orig
//   [5..10] f32: freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
//   [11..14] sections
// theta_scale = freq_base^(-2/n_dims) and the YaRN correction dims are precomputed here.
// All floats go into the params buffer as raw IEEE-754 bits.
//
// Bindings: 0 = src0, 1 = src1, 2 = src2 (only when present). dst follows at the next slot
// (2 or 3) unless the op is in-place.
// The pipeline is ctx->rope_pipelines[dst->type][has_freq_factor][inplace].
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
                                       ggml_tensor *    src0,
                                       ggml_tensor *    src1,
                                       ggml_tensor *    src2,
                                       ggml_tensor *    dst) {
    // Copy a float's bit pattern into a u32. Dereferencing (uint32_t *) &some_float violates
    // strict aliasing (UB); memcpy is the well-defined way to type-pun.
    auto f32_bits = [](float v) -> uint32_t {
        uint32_t bits;
        memcpy(&bits, &v, sizeof(bits));
        return bits;
    };

    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
    const int has_freq_factor = (src2 != nullptr);

    const int32_t * op_params  = (const int32_t *) dst->op_params;
    const int       n_dims     = op_params[1];
    const int       mode       = op_params[2];
    const int       n_ctx_orig = op_params[4];

    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;
    memcpy(&freq_base, op_params + 5, sizeof(float));
    memcpy(&freq_scale, op_params + 6, sizeof(float));
    memcpy(&ext_factor, op_params + 7, sizeof(float));
    memcpy(&attn_factor, op_params + 8, sizeof(float));
    memcpy(&beta_fast, op_params + 9, sizeof(float));
    memcpy(&beta_slow, op_params + 10, sizeof(float));

    // Same element type as op_params, so the copy size is exact on every platform.
    int32_t sections[4];
    memcpy(sections, op_params + 11, sizeof(sections));

    float theta_scale = powf(freq_base, -2.0f / n_dims);

    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(src0) / 2,
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        (uint32_t) n_dims,
        (uint32_t) mode,
        f32_bits(theta_scale),
        f32_bits(attn_factor),
        f32_bits(freq_scale),
        f32_bits(ext_factor),
        f32_bits(corr_dims[0]),
        f32_bits(corr_dims[1]),
        (uint32_t) sections[0],
        (uint32_t) sections[1],
        (uint32_t) sections[2],
        (uint32_t) sections[3]
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(src1),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
    uint32_t dst_binding = 2;
    if (has_freq_factor) {
        dst_binding = 3;
        entries.push_back({ .binding = 2,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
    }
    if (!inplace) {
        entries.push_back({ .binding = dst_binding,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
    // NOTE(review): params carry ggml_nelements(src0) / 2, but the dispatch is sized by
    // ggml_nelements(dst). Presumably the shader bounds-checks against the param — confirm.
    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
|
|
1739
|
+
|
|
1740
|
+
// Encodes a GLU-family op, with the variant chosen by ggml_get_glu_op(dst), as one dispatch of
// ceil(ne(dst) / WEBGPU_MAX_WG_SIZE) workgroups.
//
// When src1 is null, only src0 is bound: src1's misalignment is sent as 0 and src0's strides
// fill the second-operand stride slots. Otherwise src1 is bound separately (the "split"
// pipeline variant).
//
// Params (u32 words), in order:
//   - src0/src1/dst misalignment (in elements)
//   - src0 nb[1..3], second-operand nb[1..3], dst nb[1..3] (all in elements)
//   - ne(dst), dst ne[0..2]
//   - op_params[1] (swapped), op_params[2] (alpha) and op_params[3] (limit) — the last two
//     are used by swiglu_oai
// Bindings: 0 = src0, [1 = src1 when split], then dst.
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
    const int split = (src1 != nullptr);

    // Converts a byte quantity into an element count of t's type.
    auto in_elems = [](uint64_t bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };

    // Tensor whose strides describe the second operand: src1 if present, else src0 itself.
    const ggml_tensor * second = split ? src1 : src0;

    std::vector<uint32_t> params;
    params.reserve(19);
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, src0), src0));
    params.push_back(split ? in_elems(ggml_webgpu_tensor_misalignment(ctx, src1), src1) : 0);
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, dst), dst));
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(src0->nb[d], src0));
    }
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(second->nb[d], second));
    }
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(dst->nb[d], dst));
    }
    params.push_back((uint32_t) ggml_nelements(dst));
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) dst->ne[d]);
    }
    params.push_back((uint32_t) dst->op_params[1]);  // swapped
    params.push_back((uint32_t) dst->op_params[2]);  // alpha, for swiglu_oai (f32 bits)
    params.push_back((uint32_t) dst->op_params[3]);  // limit, for swiglu_oai (f32 bits)

    std::vector<wgpu::BindGroupEntry> entries;
    uint32_t                          slot = 0;
    entries.push_back({ .binding = slot++,
                        .buffer  = ggml_webgpu_tensor_buf(src0),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src0) });
    if (split) {
        entries.push_back({ .binding = slot++,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
    }
    entries.push_back({ .binding = slot,
                        .buffer  = ggml_webgpu_tensor_buf(dst),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });

    webgpu_pipeline glu_pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
    const uint32_t  groups       = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, glu_pipeline, params, entries, groups);
}
|
|
1791
|
+
|
|
1792
|
+
// Encodes a scale op (scale and bias taken from op_params[0] and op_params[1], passed as raw
// f32 bits) as one dispatch of ceil(ne(dst) / wg_size) workgroups.
//
// Params (u32 words): src/dst misalignment (in elements); src nb[1..3] and dst nb[1..3]
// (in elements); ne(dst); src ne[0..2]; scale; bias.
// Binding 0 = src; binding 1 = dst, only when the op is not in-place.
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const bool same_tensor = ggml_webgpu_tensor_equal(src, dst);

    ggml_webgpu_shader_lib_context lib_ctx = {
        .src0        = src,
        .src1        = nullptr,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .inplace     = same_tensor,
    };

    webgpu_pipeline scale_pipeline = ctx->shader_lib->get_scale_pipeline(lib_ctx);
    const auto *    choice = static_cast<ggml_webgpu_generic_shader_decisions *>(scale_pipeline.context.get());

    // Converts a byte quantity into an element count of t's type.
    auto in_elems = [](uint64_t bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };

    std::vector<uint32_t> params;
    params.reserve(14);
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, src), src));
    params.push_back(in_elems(ggml_webgpu_tensor_misalignment(ctx, dst), dst));
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(src->nb[d], src));
    }
    for (int d = 1; d <= 3; d++) {
        params.push_back(in_elems(dst->nb[d], dst));
    }
    params.push_back((uint32_t) ggml_nelements(dst));
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) src->ne[d]);
    }
    params.push_back((uint32_t) dst->op_params[0]);  // scale (f32 bits)
    params.push_back((uint32_t) dst->op_params[1]);  // bias (f32 bits)

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back({ .binding = 0,
                        .buffer  = ggml_webgpu_tensor_buf(src),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src) });
    if (!same_tensor) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    const uint32_t groups = CEIL_DIV(ggml_nelements(dst), choice->wg_size);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, scale_pipeline, params, entries, groups);
}
|
|
1842
|
+
|
|
1843
|
+
// Encodes a softmax of src0 into dst. src1 is an optional mask and src2 optional sinks.
// Dispatches ggml_nrows(dst) workgroups.
//
// The pipeline is ctx->soft_max_pipelines[mask_type][has_sink][inplace]. mask_type is the
// mask's ggml_type, or 2 when there is no mask, so "mask_type < 2" means a mask is bound. This
// assumes masks are only of the two lowest-numbered types (presumably F32/F16 — TODO confirm).
//
// op_params: [0] scale (f32 bits, forwarded as-is), [1] max_bias (f32).
// n_head_log2, m0 and m1 are derived from max_bias and src0->ne[2] here. All floats are
// passed as raw IEEE-754 bits.
//
// Bindings: 0 = src0, then the mask (if any), then the sinks (if any), then dst unless in-place.
static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                                           ggml_tensor *    src0,
                                           ggml_tensor *    src1,
                                           ggml_tensor *    src2,
                                           ggml_tensor *    dst) {
    // Copy a float's bit pattern into a u32. Dereferencing (uint32_t *) &some_float violates
    // strict aliasing (UB); memcpy is the well-defined way to type-pun.
    auto f32_bits = [](float v) -> uint32_t {
        uint32_t bits;
        memcpy(&bits, &v, sizeof(bits));
        return bits;
    };

    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
    const int has_sink  = (src2 != nullptr);

    float max_bias;
    memcpy(&max_bias, &dst->op_params[1], sizeof(float));
    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(dst),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
        *(uint32_t *) dst->op_params,  // scale (int32 storage read as u32: permitted aliasing)
        f32_bits(max_bias),
        f32_bits(n_head_log2),
        f32_bits(m0),
        f32_bits(m1)
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
    };
    uint32_t binding_num = 1;
    if (mask_type < 2) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
        binding_num++;
    }
    if (has_sink) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
        binding_num++;
    }
    if (!inplace) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
                                     ggml_nrows(dst));
}
|
|
1916
|
+
|
|
1917
|
+
// Encodes an argmax of src into dst, dispatching ggml_nelements(dst) workgroups.
//
// Params (u32 words): src misalignment and dst misalignment (in elements), src ne[0].
// Bindings: 0 = src, 1 = dst.
static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context lib_ctx = {
        .src0        = src,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };
    webgpu_pipeline argmax_pipeline = ctx->shader_lib->get_argmax_pipeline(lib_ctx);

    const uint32_t src_misalign = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
    const uint32_t dst_misalign = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
    const uint32_t row_len      = (uint32_t) src->ne[0];
    std::vector<uint32_t> params = { src_misalign, dst_misalign, row_len };

    std::vector<wgpu::BindGroupEntry> entries(2);
    entries[0] = { .binding = 0,
                   .buffer  = ggml_webgpu_tensor_buf(src),
                   .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                   .size    = ggml_webgpu_tensor_binding_size(ctx, src) };
    entries[1] = { .binding = 1,
                   .buffer  = ggml_webgpu_tensor_buf(dst),
                   .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                   .size    = ggml_webgpu_tensor_binding_size(ctx, dst) };

    const uint32_t groups = (uint32_t) ggml_nelements(dst);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, argmax_pipeline, params, entries, groups);
}
|
|
1941
|
+
|
|
1942
|
+
// Encodes GGML_OP_ARGSORT, and GGML_OP_TOP_K (which reuses the same kernels),
// as one multi-dispatch command:
//   1. an initial pass (argsort pipeline) over ceil(src->ne[0] / wg_size) tiles
//      per row, producing sorted runs of `block_size` indices;
//   2. zero or more merge passes (argsort-merge pipeline), each merging pairs of
//      runs of length `len` and doubling `len`, until one run covers the
//      `out_ne0` indices kept per row.
// Indices ping-pong between the dst region and a scratch ("tmp") region that
// starts at the first storage-aligned offset after the dst indices in the same
// buffer; ggml_backend_webgpu_buffer_type_get_alloc_size reserves that space.
// The starting region is chosen from the parity of the merge-pass count so the
// final pass always writes into dst. For TOP_K, each tile keeps at most
// k = dst->ne[0] indices and the final merge pass emits only k indices per row.
static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const bool is_top_k = dst->op == GGML_OP_TOP_K;

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0               = src,
        .src1               = nullptr,
        .dst                = dst,
        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
    };

    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
    auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
    const uint32_t wg_size   = argsort_decisions->wg_size;

    const uint32_t src_ne0 = (uint32_t) src->ne[0];
    const uint32_t nrows   = (uint32_t) ggml_nrows(src);
    // Tiles (initial-pass workgroups) per row.
    const uint32_t npr     = CEIL_DIV(src_ne0, wg_size);
    // Sorted-run length produced by the initial pass; a TOP_K tile only keeps its best k indices.
    const uint32_t block_size = is_top_k ? std::min(wg_size, (uint32_t) dst->ne[0]) : wg_size;

    // Number of indices per row carried between passes.
    uint32_t out_ne0 = src_ne0;
    if (is_top_k) {
        if (npr > 1) {
            // Full tiles contribute block_size indices each; the last (possibly short) tile at most that many.
            const uint32_t last_tile = src_ne0 - (npr - 1) * wg_size;
            out_ne0                  = (npr - 1) * block_size + std::min(last_tile, block_size);
        } else {
            out_ne0 = block_size;
        }
    }

    // Each merge pass doubles the run length until a single run spans out_ne0.
    uint32_t merge_passes = 0;
    for (uint32_t run_len = block_size; run_len < out_ne0; run_len <<= 1) {
        merge_passes++;
    }

    // Passes alternate regions; starting in tmp when the pass count is odd makes
    // the last pass land in dst.
    const bool start_in_tmp = (merge_passes % 2) == 1;

    const size_t dst_offset       = ggml_webgpu_tensor_offset(dst);
    const size_t dst_align_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
    const size_t idx_nbytes       = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
    // Scratch region: first storage-aligned offset after the dst indices.
    const size_t tmp_offset =
        ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
    const size_t dst_binding_size =
        ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);

    // Element offsets and strides passed to the shaders.
    const uint32_t offset_src  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
    const uint32_t offset_dst  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
    const uint32_t offset_tmp  = 0;  // tmp_offset is already storage-aligned
    const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
    const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
    const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
    const uint32_t stride_idx1 = out_ne0;
    const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
    const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];

    // Splits a flat workgroup count into an (x, y) grid whose x extent respects
    // maxComputeWorkgroupsPerDimension. A zero count yields an empty dispatch
    // instead of a division by zero.
    const uint32_t max_wg           = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
    auto           split_workgroups = [max_wg](uint32_t total) -> std::pair<uint32_t, uint32_t> {
        if (total == 0) {
            return { 0u, 0u };
        }
        const uint32_t wg_x = std::min(total, max_wg);
        const uint32_t wg_y = CEIL_DIV(total, wg_x);
        return { wg_x, wg_y };
    };

    // Binding 0 of every pass: the source values.
    const wgpu::BindGroupEntry src_entry = { .binding = 0,
                                             .buffer  = ggml_webgpu_tensor_buf(src),
                                             .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                                             .size    = ggml_webgpu_tensor_binding_size(ctx, src) };
    // Binding for one of the two index regions in dst's buffer: scratch (use_tmp) or dst itself.
    auto idx_entry = [&](uint32_t binding, bool use_tmp) -> wgpu::BindGroupEntry {
        return { .binding = binding,
                 .buffer  = ggml_webgpu_tensor_buf(dst),
                 .offset  = use_tmp ? tmp_offset : dst_align_offset,
                 .size    = use_tmp ? tmp_binding_size : dst_binding_size };
    };
    // Element offset of the chosen index region inside its binding.
    auto idx_elem_offset = [&](bool use_tmp) -> uint32_t { return use_tmp ? offset_tmp : offset_dst; };

    std::vector<webgpu_pipeline>                   pipelines;
    std::vector<std::vector<uint32_t>>             params_list;
    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
    std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;

    // Initial pass: sort every tile into the starting region.
    pipelines.push_back(argsort_pipeline);
    params_list.push_back(std::vector<uint32_t>{ offset_src, idx_elem_offset(start_in_tmp), stride_src1, stride_src2,
                                                 stride_src3, stride_idx1, stride_idx2, stride_idx3, src_ne0,
                                                 (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0, block_size,
                                                 npr, nrows });
    entries_list.push_back(std::vector<wgpu::BindGroupEntry>{ src_entry, idx_entry(1, start_in_tmp) });
    workgroups_list.push_back(split_workgroups(npr * nrows));

    if (merge_passes > 0) {
        // Only request the merge pipeline when a merge pass is actually encoded.
        webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);

        bool in_is_tmp = start_in_tmp;
        for (uint32_t len = block_size; len < out_ne0; len <<= 1) {
            // Merges per row in this pass: pairs of runs of length len.
            const uint32_t nm         = CEIL_DIV(out_ne0, 2 * len);
            const bool     out_is_tmp = !in_is_tmp;
            // The final TOP_K merge (a single merge per row) writes only k indices per row.
            const uint32_t top_k_out   = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
            const uint32_t stride_out1 = top_k_out;
            const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
            const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];

            pipelines.push_back(argsort_merge_pipeline);
            params_list.push_back(std::vector<uint32_t>{
                offset_src, idx_elem_offset(in_is_tmp), idx_elem_offset(out_is_tmp), stride_src1, stride_src2,
                stride_src3, stride_idx1, stride_idx2, stride_idx3, stride_out1, stride_out2, stride_out3, out_ne0,
                (uint32_t) src->ne[1], (uint32_t) src->ne[2], top_k_out, len, nm, nrows });
            entries_list.push_back(
                std::vector<wgpu::BindGroupEntry>{ src_entry, idx_entry(1, in_is_tmp), idx_entry(2, out_is_tmp) });
            workgroups_list.push_back(split_workgroups(nm * nrows));

            in_is_tmp = out_is_tmp;
        }
    }

    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
                                           workgroups_list);
}
|
|
692
2098
|
|
|
693
|
-
static
|
|
694
|
-
|
|
2099
|
+
// Encodes GGML_OP_CUMSUM as a single dispatch with ggml_nrows(dst) workgroups along x.
// Shader params, in order: element offset of src and of dst within their aligned
// bindings, then src->ne[0]. Bindings: 0 = src, 1 = dst.
static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    // Misalignment of a tensor inside its aligned binding, expressed in elements.
    auto elem_offset = [&ctx](ggml_tensor * t) -> uint32_t {
        return (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, t) / ggml_type_size(t->type));
    };
    // Whole-tensor storage binding for `t` at binding slot `slot`.
    auto bind_tensor = [&ctx](uint32_t slot, ggml_tensor * t) -> wgpu::BindGroupEntry {
        return { .binding = slot,
                 .buffer  = ggml_webgpu_tensor_buf(t),
                 .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                 .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };

    std::vector<uint32_t> params;
    params.push_back(elem_offset(src));
    params.push_back(elem_offset(dst));
    params.push_back((uint32_t) src->ne[0]);

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back(bind_tensor(0, src));
    entries.push_back(bind_tensor(1, dst));

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0        = src,
        .src1        = nullptr,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);

    const uint32_t row_count = (uint32_t) ggml_nrows(dst);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, row_count);
}
|
|
2126
|
+
|
|
2127
|
+
// Encodes GGML_OP_SUM_ROWS, and GGML_OP_SUM as a special case of it.
// Shader params, in order: element offsets of src and dst, src element strides
// for dims 1..3, then src extents ne0, ne1, ne2.
//  - SUM_ROWS: real strides/extents of src, one workgroup per row of dst.
//  - SUM: src is described as a single row of ggml_nelements(src) values with
//    zero strides, and one workgroup is dispatched.
//    NOTE(review): this flat view assumes src is contiguous — confirm supports_op enforces it.
static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const size_t src_type_size = ggml_type_size(src->type);

    uint32_t stride1;
    uint32_t stride2;
    uint32_t stride3;
    uint32_t ne0;
    uint32_t ne1;
    uint32_t ne2;
    uint32_t wg_x;
    if (dst->op == GGML_OP_SUM) {
        stride1 = 0;
        stride2 = 0;
        stride3 = 0;
        ne0     = (uint32_t) ggml_nelements(src);
        ne1     = 1;
        ne2     = 1;
        wg_x    = 1;
    } else {
        stride1 = (uint32_t) (src->nb[1] / src_type_size);
        stride2 = (uint32_t) (src->nb[2] / src_type_size);
        stride3 = (uint32_t) (src->nb[3] / src_type_size);
        ne0     = (uint32_t) src->ne[0];
        ne1     = (uint32_t) src->ne[1];
        ne2     = (uint32_t) src->ne[2];
        wg_x    = (uint32_t) ggml_nrows(dst);
    }

    const uint32_t src_elem_offset = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_type_size);
    const uint32_t dst_elem_offset =
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));

    std::vector<uint32_t> params = { src_elem_offset, dst_elem_offset, stride1, stride2, stride3, ne0, ne1, ne2 };

    std::vector<wgpu::BindGroupEntry> entries(2);
    entries[0] = { .binding = 0,
                   .buffer  = ggml_webgpu_tensor_buf(src),
                   .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                   .size    = ggml_webgpu_tensor_binding_size(ctx, src) };
    entries[1] = { .binding = 1,
                   .buffer  = ggml_webgpu_tensor_buf(dst),
                   .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                   .size    = ggml_webgpu_tensor_binding_size(ctx, dst) };

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0        = src,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };
    webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);

    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
|
|
742
2158
|
|
|
743
|
-
// Returns
|
|
744
|
-
static
|
|
2159
|
+
// Returns the encoded command, or std::nullopt if the operation is a no-op
|
|
2160
|
+
static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|
745
2161
|
if (ggml_is_empty(node)) {
|
|
746
|
-
return
|
|
2162
|
+
return std::nullopt;
|
|
2163
|
+
}
|
|
2164
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
2165
|
+
return std::nullopt;
|
|
747
2166
|
}
|
|
748
2167
|
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
|
|
749
2168
|
|
|
750
2169
|
ggml_tensor * src0 = node->src[0];
|
|
751
2170
|
ggml_tensor * src1 = node->src[1];
|
|
2171
|
+
ggml_tensor * src2 = node->src[2];
|
|
752
2172
|
|
|
753
2173
|
switch (node->op) {
|
|
754
2174
|
// no-ops
|
|
@@ -757,55 +2177,122 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|
|
757
2177
|
case GGML_OP_PERMUTE:
|
|
758
2178
|
case GGML_OP_TRANSPOSE:
|
|
759
2179
|
case GGML_OP_RESHAPE:
|
|
760
|
-
return
|
|
2180
|
+
return std::nullopt;
|
|
761
2181
|
case GGML_OP_CPY:
|
|
762
|
-
|
|
763
|
-
|
|
2182
|
+
case GGML_OP_CONT:
|
|
2183
|
+
return ggml_webgpu_cpy(ctx, src0, node);
|
|
764
2184
|
case GGML_OP_SET_ROWS:
|
|
765
|
-
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
|
766
|
-
break;
|
|
2185
|
+
return ggml_webgpu_set_rows(ctx, src0, src1, node);
|
|
767
2186
|
case GGML_OP_GET_ROWS:
|
|
768
|
-
ggml_webgpu_get_rows(ctx, src0, src1, node);
|
|
769
|
-
break;
|
|
2187
|
+
return ggml_webgpu_get_rows(ctx, src0, src1, node);
|
|
770
2188
|
case GGML_OP_MUL_MAT:
|
|
771
|
-
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
|
772
|
-
|
|
2189
|
+
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
|
2190
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
2191
|
+
#ifndef __EMSCRIPTEN__
|
|
2192
|
+
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
|
2193
|
+
#else
|
|
2194
|
+
return std::nullopt;
|
|
2195
|
+
#endif
|
|
773
2196
|
case GGML_OP_ADD:
|
|
774
|
-
|
|
775
|
-
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
|
|
776
|
-
} else {
|
|
777
|
-
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
|
|
778
|
-
}
|
|
779
|
-
break;
|
|
2197
|
+
case GGML_OP_SUB:
|
|
780
2198
|
case GGML_OP_MUL:
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
2199
|
+
case GGML_OP_DIV:
|
|
2200
|
+
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
|
2201
|
+
case GGML_OP_CONCAT:
|
|
2202
|
+
return ggml_webgpu_concat(ctx, src0, src1, node);
|
|
2203
|
+
case GGML_OP_REPEAT:
|
|
2204
|
+
return ggml_webgpu_repeat(ctx, src0, node);
|
|
787
2205
|
case GGML_OP_RMS_NORM:
|
|
788
|
-
ggml_webgpu_rms_norm(ctx, src0, node);
|
|
789
|
-
|
|
2206
|
+
return ggml_webgpu_rms_norm(ctx, src0, node);
|
|
2207
|
+
case GGML_OP_ROPE:
|
|
2208
|
+
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
|
|
2209
|
+
case GGML_OP_GLU:
|
|
2210
|
+
return ggml_webgpu_glu(ctx, src0, src1, node);
|
|
2211
|
+
case GGML_OP_SCALE:
|
|
2212
|
+
return ggml_webgpu_scale(ctx, src0, node);
|
|
2213
|
+
case GGML_OP_SOFT_MAX:
|
|
2214
|
+
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
|
2215
|
+
case GGML_OP_UNARY:
|
|
2216
|
+
case GGML_OP_CLAMP:
|
|
2217
|
+
case GGML_OP_FILL:
|
|
2218
|
+
case GGML_OP_LOG:
|
|
2219
|
+
case GGML_OP_SQR:
|
|
2220
|
+
case GGML_OP_SQRT:
|
|
2221
|
+
case GGML_OP_SIN:
|
|
2222
|
+
case GGML_OP_COS:
|
|
2223
|
+
return ggml_webgpu_unary_op(ctx, src0, node);
|
|
2224
|
+
case GGML_OP_PAD:
|
|
2225
|
+
return ggml_webgpu_pad(ctx, src0, node);
|
|
2226
|
+
case GGML_OP_ARGMAX:
|
|
2227
|
+
return ggml_webgpu_argmax(ctx, src0, node);
|
|
2228
|
+
case GGML_OP_ARGSORT:
|
|
2229
|
+
case GGML_OP_TOP_K:
|
|
2230
|
+
// we reuse the same argsort implementation for top_k
|
|
2231
|
+
return ggml_webgpu_argsort(ctx, src0, node);
|
|
2232
|
+
case GGML_OP_CUMSUM:
|
|
2233
|
+
return ggml_webgpu_cumsum(ctx, src0, node);
|
|
2234
|
+
case GGML_OP_SUM:
|
|
2235
|
+
case GGML_OP_SUM_ROWS:
|
|
2236
|
+
return ggml_webgpu_sum_rows(ctx, src0, node);
|
|
790
2237
|
default:
|
|
791
|
-
return
|
|
2238
|
+
return std::nullopt;
|
|
792
2239
|
}
|
|
793
|
-
return true;
|
|
794
2240
|
}
|
|
795
2241
|
|
|
796
2242
|
// Encodes every node of `cgraph` and submits the resulting commands to the GPU queue.
//
// Nodes for which ggml_webgpu_encode_node returns no command are skipped.
// Commands are submitted in batches: once at least WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
// kernels are pending, they are submitted, events are processed, and
// ggml_backend_webgpu_wait is called with `false`. The `false` flag presumably
// requests a non-blocking check; confirm against its definition.
// If the graph contains any SET_ROWS node, the device-side SET_ROWS error buffer
// is copied to the host-mappable buffer and checked. A non-zero first word aborts.
// Finally waits on all recorded submissions and returns GGML_STATUS_SUCCESS.
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");

    auto *         backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
    webgpu_context ctx         = backend_ctx->webgpu_ctx;

    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

    std::vector<webgpu_command>    pending;
    std::vector<webgpu_submission> in_flight;
    uint32_t                       pending_kernels = 0;
    bool                           has_set_rows    = false;

    // Hands the currently pending commands to the queue and records the submission.
    auto submit_pending = [&]() {
        in_flight.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, pending, ctx->param_buf_pool));
    };

    for (int n = 0; n < cgraph->n_nodes; ++n) {
        ggml_tensor * node = cgraph->nodes[n];
        has_set_rows       = has_set_rows || node->op == GGML_OP_SET_ROWS;

        std::optional<webgpu_command> encoded = ggml_webgpu_encode_node(ctx, node);
        if (encoded.has_value()) {
            pending.push_back(*encoded);
            pending_kernels += encoded->num_kernels;
        }

        if (pending_kernels < WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
            continue;
        }

        pending_kernels = 0;
        submit_pending();
        // Drive event callbacks, then retire submissions that have already finished.
        ctx->global_ctx->instance.ProcessEvents();
        ggml_backend_webgpu_wait(ctx->global_ctx, in_flight, false);
        pending.clear();
    }

    if (!pending.empty()) {
        submit_pending();
        pending.clear();
    }

    if (has_set_rows) {
        // Read back the device-side SET_ROWS error flag.
        wgpu::Buffer & host_err      = ctx->set_rows_host_error_buf;
        const uint64_t host_err_size = host_err.GetSize();

        wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
        encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, host_err, 0, host_err_size);
        wgpu::CommandBuffer readback = encoder.Finish();
        ctx->global_ctx->queue.Submit(1, &readback);

        ggml_backend_webgpu_map_buffer(ctx->global_ctx, host_err, wgpu::MapMode::Read, 0, host_err_size);
        const uint32_t error_flag = *static_cast<const uint32_t *>(host_err.GetConstMappedRange());
        if (error_flag != 0) {
            GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
        }
        host_err.Unmap();
    }

    ggml_backend_webgpu_wait(ctx->global_ctx, in_flight);
    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
    return GGML_STATUS_SUCCESS;
}
|
|
811
2298
|
|
|
@@ -831,9 +2318,11 @@ static ggml_backend_i ggml_backend_webgpu_i = {
|
|
|
831
2318
|
/* GGML Backend Buffer Interface */
|
|
832
2319
|
|
|
833
2320
|
// Releases the GPU allocation behind a ggml buffer and frees its context.
// The context object is deleted whenever it exists, even if its wgpu::Buffer is
// null. Previously a context without a device buffer was never deleted, which leaked it.
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
    if (ctx == nullptr) {
        return;
    }
    if (ctx->buffer != nullptr) {
        ctx->buffer.Destroy();
    }
    delete ctx;
    // Defensive: avoid leaving a dangling pointer in the ggml buffer object.
    buffer->context = nullptr;
}
|
|
838
2327
|
|
|
839
2328
|
// Returns the "fake" base pointer.
|
|
@@ -848,20 +2337,25 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
|
|
|
848
2337
|
size_t offset,
|
|
849
2338
|
size_t size) {
|
|
850
2339
|
if (size == 0) {
|
|
851
|
-
WEBGPU_LOG_DEBUG(
|
|
2340
|
+
WEBGPU_LOG_DEBUG(
|
|
2341
|
+
"ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
|
|
2342
|
+
"nothing to do.");
|
|
852
2343
|
return;
|
|
853
2344
|
}
|
|
854
2345
|
|
|
855
|
-
|
|
856
|
-
<< offset << ", " << size << ")");
|
|
2346
|
+
WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
|
|
857
2347
|
|
|
858
2348
|
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
859
2349
|
|
|
2350
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
|
|
2351
|
+
<< ", " << offset << ", " << size << ")");
|
|
2352
|
+
|
|
860
2353
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
861
2354
|
|
|
862
2355
|
// This is a trick to set all bytes of a u32 to the same 1 byte value.
|
|
863
2356
|
uint32_t val32 = (uint32_t) value * 0x01010101;
|
|
864
|
-
ggml_backend_webgpu_buffer_memset(buf_ctx->
|
|
2357
|
+
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
|
|
2358
|
+
WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
|
|
865
2359
|
}
|
|
866
2360
|
|
|
867
2361
|
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
@@ -869,14 +2363,15 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
|
869
2363
|
const void * data,
|
|
870
2364
|
size_t offset,
|
|
871
2365
|
size_t size) {
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
2366
|
+
WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
|
|
2367
|
+
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
2368
|
+
|
|
2369
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
|
|
2370
|
+
<< ", " << offset << ", " << size << ")");
|
|
876
2371
|
|
|
877
2372
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
878
2373
|
|
|
879
|
-
|
|
2374
|
+
buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
|
|
880
2375
|
|
|
881
2376
|
if (size % 4 != 0) {
|
|
882
2377
|
// If size is not a multiple of 4, we need to memset the remaining bytes
|
|
@@ -889,12 +2384,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|
|
889
2384
|
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
|
|
890
2385
|
}
|
|
891
2386
|
// memset the remaining bytes
|
|
892
|
-
ggml_backend_webgpu_buffer_memset(
|
|
893
|
-
remaining_size);
|
|
2387
|
+
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
|
|
2388
|
+
total_offset + (size - remaining_size), remaining_size);
|
|
894
2389
|
} else {
|
|
895
2390
|
// wait for WriteBuffer to complete
|
|
896
|
-
|
|
2391
|
+
buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
|
|
2392
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
2393
|
+
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
|
2394
|
+
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
|
2395
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
|
|
2396
|
+
std::string(message).c_str());
|
|
2397
|
+
}
|
|
2398
|
+
}),
|
|
2399
|
+
UINT64_MAX);
|
|
897
2400
|
}
|
|
2401
|
+
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
|
|
898
2402
|
}
|
|
899
2403
|
|
|
900
2404
|
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
@@ -902,54 +2406,60 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|
|
902
2406
|
void * data,
|
|
903
2407
|
size_t offset,
|
|
904
2408
|
size_t size) {
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
wgpu::Device device = webgpu_ctx->device;
|
|
2409
|
+
WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
|
|
2410
|
+
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
|
|
2411
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
|
|
2412
|
+
<< ", " << offset << ", " << size << ")");
|
|
2413
|
+
wgpu::Device device = buf_ctx->global_ctx->device;
|
|
911
2414
|
|
|
912
2415
|
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
913
2416
|
|
|
914
2417
|
size_t final_size = size;
|
|
915
2418
|
if (size % 4 != 0) {
|
|
916
|
-
// If size is not a multiple of 4, we need to round it up to the next
|
|
2419
|
+
// If size is not a multiple of 4, we need to round it up to the next
|
|
2420
|
+
// multiple of 4
|
|
917
2421
|
final_size = size + (4 - (size % 4));
|
|
918
2422
|
}
|
|
919
2423
|
|
|
920
|
-
std::lock_guard<std::recursive_mutex> lock(
|
|
2424
|
+
std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
|
|
921
2425
|
|
|
922
|
-
if (
|
|
2426
|
+
if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
|
|
2427
|
+
buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
|
|
923
2428
|
// Create a new staging buffer if it doesn't exist or is too small
|
|
924
|
-
if (
|
|
925
|
-
|
|
2429
|
+
if (buf_ctx->global_ctx->get_tensor_staging_buf) {
|
|
2430
|
+
buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
|
|
926
2431
|
}
|
|
927
|
-
ggml_webgpu_create_buffer(device,
|
|
2432
|
+
ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
|
|
928
2433
|
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
|
|
929
2434
|
}
|
|
930
2435
|
|
|
931
2436
|
// Copy the data from the buffer to the staging buffer
|
|
932
2437
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
|
933
|
-
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset,
|
|
2438
|
+
encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
|
|
2439
|
+
final_size);
|
|
934
2440
|
wgpu::CommandBuffer commands = encoder.Finish();
|
|
935
2441
|
|
|
936
2442
|
// Submit the command buffer to the queue
|
|
937
|
-
|
|
2443
|
+
buf_ctx->global_ctx->queue.Submit(1, &commands);
|
|
938
2444
|
|
|
939
2445
|
// Map the staging buffer to read the data
|
|
940
|
-
ggml_backend_webgpu_map_buffer(
|
|
2446
|
+
ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
|
|
2447
|
+
wgpu::MapMode::Read, 0, final_size);
|
|
941
2448
|
// Must specify size here since the staging buffer might be larger than the tensor size
|
|
942
|
-
const void * mapped_range =
|
|
2449
|
+
const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
|
|
943
2450
|
|
|
944
2451
|
// Copy the data from the mapped range to the output buffer
|
|
945
2452
|
std::memcpy(data, mapped_range, size);
|
|
946
|
-
|
|
2453
|
+
buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
|
|
2454
|
+
WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
|
|
947
2455
|
}
|
|
948
2456
|
|
|
949
2457
|
// Fills every byte of the buffer (buffer->size bytes from offset 0) with `value`.
// ggml_backend_webgpu_buffer_memset takes a 32-bit fill pattern. In this file,
// memset_tensor replicates its byte with `* 0x01010101`, and set_tensor passes a
// u32 composed of data bytes. The byte is therefore replicated into all four lanes
// here; passing the raw byte would set only the low byte of each 32-bit word.
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
    WEBGPU_CPU_PROFILE_TOTAL_START(clear);
    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
    const uint32_t                       val32   = (uint32_t) value * 0x01010101u;
    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, 0, buffer->size);
    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
}
|
|
954
2464
|
|
|
955
2465
|
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
|
|
@@ -961,7 +2471,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
|
|
|
961
2471
|
/* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
|
|
962
2472
|
/* .cpy_tensor = */ NULL, // TODO: optional, implement this
|
|
963
2473
|
/* .clear = */ ggml_backend_webgpu_buffer_clear,
|
|
964
|
-
/* .reset = */ NULL, // TODO: optional, think it coordinates with
|
|
2474
|
+
/* .reset = */ NULL, // TODO: optional, think it coordinates with
|
|
2475
|
+
// .init_tensor
|
|
965
2476
|
};
|
|
966
2477
|
|
|
967
2478
|
/* End GGML Backend Buffer Interface */
|
|
@@ -975,29 +2486,61 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
|
|
|
975
2486
|
|
|
976
2487
|
// Allocates a device buffer for the WebGPU buffer type and wraps it in a ggml buffer.
// The GPU allocation is rounded up to WEBGPU_STORAGE_BUF_BINDING_MULT, while the
// size reported to ggml stays the requested `size`. Each buffer is labelled
// "tensor_buf<N>", where N comes from a function-static atomic counter.
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                          size_t                     size) {
    static std::atomic<int> buffer_count;
    const int               id    = buffer_count.fetch_add(1);
    const std::string       label = std::string("tensor_buf") + std::to_string(id);
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << id << ": " << size << " bytes");

    auto *     dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    const auto usage   = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;

    wgpu::Buffer gpu_buf;
    ggml_webgpu_create_buffer(dev_ctx->webgpu_global_ctx->device, gpu_buf,
                              ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), usage, label.c_str());

    auto * buf_ctx = new ggml_backend_webgpu_buffer_context(gpu_buf, label, dev_ctx->webgpu_global_ctx);
    return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
}
|
|
991
2505
|
|
|
992
2506
|
// Alignment ggml must use for tensor offsets in this buffer type: the device's
// minStorageBufferOffsetAlignment limit.
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    const auto & limits  = dev_ctx->webgpu_global_ctx->capabilities.limits;
    return limits.minStorageBufferOffsetAlignment;
}
|
|
996
2511
|
|
|
997
|
-
// maxBufferSize might be larger, but you can't bind more than
|
|
2512
|
+
// maxBufferSize might be larger, but you can't bind more than
|
|
2513
|
+
// maxStorageBufferBindingSize to a single binding.
|
|
998
2514
|
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
2515
|
+
ggml_backend_webgpu_device_context * dev_ctx =
|
|
2516
|
+
static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
|
2517
|
+
return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
|
|
2518
|
+
}
|
|
2519
|
+
|
|
2520
|
+
// Bytes ggml must reserve for `tensor` in a WebGPU buffer.
// Most ops need exactly ggml_nbytes(tensor). ARGSORT and TOP_K also need scratch
// space after their output, because ggml_webgpu_argsort ping-pongs index data
// between dst and a storage-aligned tmp region in the same buffer:
//  - ARGSORT reserves 2 * nbytes + minStorageBufferOffsetAlignment;
//  - TOP_K carries up to src->ne[0] indices per row between passes (more than the
//    k stored in dst), so it reserves 2 * (sizeof(int32_t) * nelements(src)) + alignment.
// Both results are rounded up to WEBGPU_STORAGE_BUF_BINDING_MULT. A TOP_K tensor
// without src[0] falls back to ggml_nbytes(tensor).
static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
                                                               const ggml_tensor *        tensor) {
    const auto * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    const size_t nbytes  = ggml_nbytes(tensor);

    // Output-plus-scratch footprint for `idx_bytes` bytes of indices.
    auto with_scratch = [dev_ctx](size_t idx_bytes) -> size_t {
        return ROUNDUP_POW2(
            idx_bytes * 2 + dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
            WEBGPU_STORAGE_BUF_BINDING_MULT);
    };

    if (tensor->op == GGML_OP_ARGSORT) {
        return with_scratch(nbytes);
    }
    if (tensor->op == GGML_OP_TOP_K && tensor->src[0] != nullptr) {
        const size_t full_idx_bytes = sizeof(int32_t) * ggml_nelements(tensor->src[0]);
        return with_scratch(full_idx_bytes);
    }
    return nbytes;
}
|
|
1002
2545
|
|
|
1003
2546
|
/* End GGML Backend Buffer Type Interface */
|
|
@@ -1016,9 +2559,18 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
|
|
|
1016
2559
|
|
|
1017
2560
|
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
1018
2561
|
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
1019
|
-
// TODO:
|
|
1020
|
-
|
|
1021
|
-
|
|
2562
|
+
// TODO: for now, return maxBufferSize as both free and total memory
|
|
2563
|
+
// Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
|
|
2564
|
+
uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
|
|
2565
|
+
// If we're on a 32-bit system, clamp to UINTPTR_MAX
|
|
2566
|
+
#if UINTPTR_MAX < UINT64_MAX
|
|
2567
|
+
uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
|
|
2568
|
+
if (max_buffer_size > max_ptr_size) {
|
|
2569
|
+
max_buffer_size = max_ptr_size;
|
|
2570
|
+
}
|
|
2571
|
+
#endif
|
|
2572
|
+
*free = static_cast<size_t>(max_buffer_size);
|
|
2573
|
+
*total = static_cast<size_t>(max_buffer_size);
|
|
1022
2574
|
}
|
|
1023
2575
|
|
|
1024
2576
|
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
|
|
@@ -1044,205 +2596,382 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
|
|
1044
2596
|
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
|
|
1045
2597
|
}
|
|
1046
2598
|
|
|
1047
|
-
|
|
1048
|
-
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
|
|
1049
|
-
std::vector<wgpu::ConstantEntry> constants(1);
|
|
1050
|
-
constants[0].key = "wg_size";
|
|
1051
|
-
constants[0].value = webgpu_ctx->max_wg_size_x;
|
|
1052
|
-
return constants;
|
|
1053
|
-
}
|
|
1054
|
-
|
|
1055
|
-
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
|
|
2599
|
+
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
|
1056
2600
|
// we use the maximum workgroup size for the memset pipeline
|
|
1057
|
-
size_t
|
|
1058
|
-
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
|
|
2601
|
+
size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
|
1059
2602
|
// Size the bytes_per_thread so that the largest buffer size can be handled
|
|
1060
|
-
|
|
1061
|
-
(
|
|
2603
|
+
ctx->capabilities.memset_bytes_per_thread =
|
|
2604
|
+
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
|
|
1062
2605
|
std::vector<wgpu::ConstantEntry> constants(2);
|
|
1063
|
-
constants[0].key
|
|
1064
|
-
constants[0].value
|
|
1065
|
-
constants[1].key
|
|
1066
|
-
constants[1].value
|
|
1067
|
-
ggml_webgpu_create_pipeline(
|
|
1068
|
-
}
|
|
1069
|
-
|
|
1070
|
-
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
|
1071
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
|
|
1072
|
-
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
|
|
1073
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
|
|
1074
|
-
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
|
|
1075
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
|
|
1076
|
-
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
|
|
1077
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
|
|
1078
|
-
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
|
|
1079
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
|
|
1080
|
-
wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
|
|
1081
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
|
|
1082
|
-
wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
|
|
1083
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
|
|
1084
|
-
wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
|
|
1085
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
|
|
1086
|
-
wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
|
|
1087
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
|
|
1088
|
-
wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
|
|
1089
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
|
|
1090
|
-
wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
|
|
1091
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
|
|
1092
|
-
wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
|
|
1093
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
|
|
1094
|
-
wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
|
|
1095
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
|
|
1096
|
-
wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
|
|
1097
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
|
|
1098
|
-
wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
|
|
1099
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
|
|
1100
|
-
wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
|
|
1101
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
|
|
1102
|
-
wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
|
|
1103
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
|
|
1104
|
-
wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
|
|
1105
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
|
|
1106
|
-
wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
|
|
1107
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
|
|
1108
|
-
wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
|
|
1109
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
|
|
1110
|
-
wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
|
|
1111
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
|
|
1112
|
-
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
|
|
1113
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
|
|
1114
|
-
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
|
|
1115
|
-
}
|
|
1116
|
-
|
|
1117
|
-
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
|
1118
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
|
|
1119
|
-
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
|
1120
|
-
}
|
|
1121
|
-
|
|
1122
|
-
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
|
1123
|
-
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
|
1124
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
|
|
1125
|
-
"get_rows_f32_vec", constants);
|
|
1126
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
|
|
1127
|
-
"get_rows_f32", constants);
|
|
1128
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
|
|
1129
|
-
"get_rows_f16", constants);
|
|
1130
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
|
|
1131
|
-
"get_rows_i32", constants);
|
|
1132
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
|
|
1133
|
-
"get_rows_q4_0", constants);
|
|
1134
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
|
|
1135
|
-
"get_rows_q4_1", constants);
|
|
1136
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
|
|
1137
|
-
"get_rows_q5_0", constants);
|
|
1138
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
|
|
1139
|
-
"get_rows_q5_1", constants);
|
|
1140
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
|
|
1141
|
-
"get_rows_q8_0", constants);
|
|
1142
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
|
|
1143
|
-
"get_rows_q2_k", constants);
|
|
1144
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
|
|
1145
|
-
"get_rows_q3_k", constants);
|
|
1146
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
|
|
1147
|
-
"get_rows_q4_k", constants);
|
|
1148
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
|
|
1149
|
-
"get_rows_q5_k", constants);
|
|
1150
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
|
|
1151
|
-
"get_rows_q6_k", constants);
|
|
1152
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
|
|
1153
|
-
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
|
|
1154
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
|
|
1155
|
-
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
|
|
1156
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
|
|
1157
|
-
"get_rows_iq2_s", constants);
|
|
1158
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
|
|
1159
|
-
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
|
|
1160
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
|
|
1161
|
-
"get_rows_iq3_s", constants);
|
|
1162
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
|
|
1163
|
-
"get_rows_iq1_s", constants);
|
|
1164
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
|
|
1165
|
-
"get_rows_iq1_m", constants);
|
|
1166
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
|
|
1167
|
-
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
|
|
1168
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
|
|
1169
|
-
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
|
|
2606
|
+
constants[0].key = "wg_size";
|
|
2607
|
+
constants[0].value = WEBGPU_MAX_WG_SIZE;
|
|
2608
|
+
constants[1].key = "bytes_per_thread";
|
|
2609
|
+
constants[1].value = ctx->capabilities.memset_bytes_per_thread;
|
|
2610
|
+
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
|
1170
2611
|
}
|
|
1171
2612
|
|
|
1172
2613
|
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
|
|
1186
|
-
"add_in_place_f16", constants);
|
|
1187
|
-
}
|
|
1188
|
-
|
|
1189
|
-
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
|
|
1190
|
-
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
|
1191
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
|
|
1192
|
-
constants);
|
|
1193
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
|
|
1194
|
-
constants);
|
|
1195
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
|
|
1196
|
-
"mul_in_place_f32", constants);
|
|
1197
|
-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
|
|
1198
|
-
"mul_in_place_f16", constants);
|
|
2614
|
+
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
|
2615
|
+
|
|
2616
|
+
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
|
|
2617
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
|
|
2618
|
+
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
|
|
2619
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
|
|
2620
|
+
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
|
|
2621
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
|
|
2622
|
+
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
|
|
2623
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
|
|
2624
|
+
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
|
|
2625
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
|
|
1199
2626
|
}
|
|
1200
2627
|
|
|
1201
2628
|
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
|
1202
|
-
std::vector<wgpu::ConstantEntry> constants =
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
2629
|
+
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
|
|
2630
|
+
|
|
2631
|
+
webgpu_ctx->rms_norm_pipelines[0] =
|
|
2632
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
|
|
2633
|
+
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
|
|
2634
|
+
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
|
|
2635
|
+
}
|
|
2636
|
+
|
|
2637
|
+
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
|
|
2638
|
+
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
|
2639
|
+
|
|
2640
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
|
|
2641
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
|
|
2642
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
|
|
2643
|
+
webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
|
|
2644
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
|
|
2645
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
|
|
2646
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
|
|
2647
|
+
webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
|
|
2648
|
+
|
|
2649
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
|
|
2650
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
|
|
2651
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
|
|
2652
|
+
webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
|
|
2653
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
|
|
2654
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
|
|
2655
|
+
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
|
|
2656
|
+
webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
|
|
2657
|
+
}
|
|
2658
|
+
|
|
2659
|
+
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
|
|
2660
|
+
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
|
2661
|
+
|
|
2662
|
+
// REGLU
|
|
2663
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
|
|
2664
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
|
|
2665
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
|
|
2666
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
|
|
2667
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
|
|
2668
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
|
|
2669
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
|
|
2670
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
|
|
2671
|
+
|
|
2672
|
+
// GEGLU
|
|
2673
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
|
|
2674
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
|
|
2675
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
|
|
2676
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
|
|
2677
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
|
|
2678
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
|
|
2679
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
|
|
2680
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
|
|
2681
|
+
|
|
2682
|
+
// SWIGLU
|
|
2683
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
|
|
2684
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
|
|
2685
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
|
|
2686
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
|
|
2687
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
|
|
2688
|
+
webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
|
|
2689
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
|
|
2690
|
+
webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
|
|
2691
|
+
|
|
2692
|
+
// SWIGLU_OAI
|
|
2693
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
|
|
2694
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
|
|
2695
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
|
|
2696
|
+
webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
|
|
2697
|
+
|
|
2698
|
+
// GEGLU_ERF
|
|
2699
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
|
|
2700
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
|
|
2701
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
|
|
2702
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
|
|
2703
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
|
|
2704
|
+
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
|
|
2705
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
|
|
2706
|
+
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
|
|
2707
|
+
|
|
2708
|
+
// GEGLU_QUICK
|
|
2709
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
|
|
2710
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
|
|
2711
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
|
|
2712
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
|
|
2713
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
|
|
2714
|
+
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
|
|
2715
|
+
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
|
|
2716
|
+
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
|
|
2717
|
+
}
|
|
2718
|
+
|
|
2719
|
+
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
|
|
2720
|
+
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
|
|
2721
|
+
|
|
2722
|
+
// f32 (no mask)
|
|
2723
|
+
webgpu_ctx->soft_max_pipelines[2][0][0] =
|
|
2724
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
|
|
2725
|
+
webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
|
|
2726
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
|
|
2727
|
+
webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
|
|
2728
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
|
|
2729
|
+
webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
|
|
2730
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
|
|
2731
|
+
|
|
2732
|
+
// f32 mask (mask_type = 0)
|
|
2733
|
+
webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
|
|
2734
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
|
|
2735
|
+
webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
|
|
2736
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
|
|
2737
|
+
webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
|
|
2738
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
|
|
2739
|
+
webgpu_ctx->soft_max_pipelines[0][1][1] =
|
|
2740
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
|
|
2741
|
+
"soft_max_f32_mask_f32_sink_inplace", constants);
|
|
2742
|
+
|
|
2743
|
+
// f16 mask (mask_type = 1)
|
|
2744
|
+
webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
|
|
2745
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
|
|
2746
|
+
webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
|
|
2747
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
|
|
2748
|
+
webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
|
|
2749
|
+
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
|
|
2750
|
+
webgpu_ctx->soft_max_pipelines[1][1][1] =
|
|
2751
|
+
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
|
|
2752
|
+
"soft_max_f32_mask_f16_sink_inplace", constants);
|
|
2753
|
+
}
|
|
2754
|
+
|
|
2755
|
+
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|
2756
|
+
wgpu::RequestAdapterOptions options = {};
|
|
2757
|
+
|
|
2758
|
+
#ifndef __EMSCRIPTEN__
|
|
2759
|
+
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
|
|
2760
|
+
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
|
|
2761
|
+
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
|
|
2762
|
+
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
|
|
2763
|
+
adapterTogglesDesc.enabledToggleCount = 2;
|
|
2764
|
+
options.nextInChain = &adapterTogglesDesc;
|
|
2765
|
+
#endif
|
|
2766
|
+
|
|
2767
|
+
ctx->webgpu_global_ctx->instance.WaitAny(
|
|
2768
|
+
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
|
2769
|
+
&options, wgpu::CallbackMode::AllowSpontaneous,
|
|
2770
|
+
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
|
2771
|
+
if (status != wgpu::RequestAdapterStatus::Success) {
|
|
2772
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
|
2773
|
+
return;
|
|
2774
|
+
}
|
|
2775
|
+
ctx->webgpu_global_ctx->adapter = std::move(adapter);
|
|
2776
|
+
}),
|
|
2777
|
+
UINT64_MAX);
|
|
2778
|
+
GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
|
|
2779
|
+
|
|
2780
|
+
ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
|
|
2781
|
+
|
|
2782
|
+
wgpu::AdapterInfo info{};
|
|
2783
|
+
#ifndef __EMSCRIPTEN__
|
|
2784
|
+
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
|
|
2785
|
+
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
|
2786
|
+
info.nextInChain = &subgroup_matrix_configs;
|
|
2787
|
+
}
|
|
2788
|
+
#endif
|
|
2789
|
+
ctx->webgpu_global_ctx->adapter.GetInfo(&info);
|
|
2790
|
+
wgpu::SupportedFeatures features;
|
|
2791
|
+
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
|
2792
|
+
// we require f16 support
|
|
2793
|
+
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
|
2794
|
+
|
|
2795
|
+
#ifndef __EMSCRIPTEN__
|
|
2796
|
+
// Only support square f16 matrices of size 8 or 16 for now
|
|
2797
|
+
bool valid_subgroup_matrix_config = false;
|
|
2798
|
+
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
|
2799
|
+
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
|
2800
|
+
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
|
2801
|
+
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
|
2802
|
+
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
|
2803
|
+
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
|
2804
|
+
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
|
2805
|
+
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
|
2806
|
+
ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
|
|
2807
|
+
valid_subgroup_matrix_config = true;
|
|
2808
|
+
break;
|
|
2809
|
+
}
|
|
2810
|
+
}
|
|
2811
|
+
}
|
|
2812
|
+
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
|
|
2813
|
+
#endif
|
|
2814
|
+
|
|
2815
|
+
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
|
2816
|
+
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
|
2817
|
+
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
|
|
2818
|
+
// Initialize device
|
|
2819
|
+
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
|
2820
|
+
|
|
2821
|
+
#ifndef __EMSCRIPTEN__
|
|
2822
|
+
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
|
|
2823
|
+
if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
|
2824
|
+
required_features.push_back(wgpu::FeatureName::Subgroups);
|
|
2825
|
+
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
|
2826
|
+
}
|
|
2827
|
+
#endif
|
|
2828
|
+
|
|
2829
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
2830
|
+
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
|
2831
|
+
#endif
|
|
2832
|
+
|
|
2833
|
+
wgpu::DeviceDescriptor dev_desc;
|
|
2834
|
+
dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
|
|
2835
|
+
dev_desc.requiredFeatures = required_features.data();
|
|
2836
|
+
dev_desc.requiredFeatureCount = required_features.size();
|
|
2837
|
+
dev_desc.SetDeviceLostCallback(
|
|
2838
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
2839
|
+
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
|
2840
|
+
if (reason == wgpu::DeviceLostReason::Destroyed) {
|
|
2841
|
+
return;
|
|
2842
|
+
}
|
|
2843
|
+
GGML_UNUSED(device);
|
|
2844
|
+
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
2845
|
+
std::string(message).c_str());
|
|
2846
|
+
});
|
|
2847
|
+
dev_desc.SetUncapturedErrorCallback(
|
|
2848
|
+
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
|
2849
|
+
GGML_UNUSED(device);
|
|
2850
|
+
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
2851
|
+
std::string(message).c_str());
|
|
2852
|
+
});
|
|
2853
|
+
|
|
2854
|
+
#ifndef __EMSCRIPTEN__
|
|
2855
|
+
// Enable Dawn-specific toggles to increase native performance
|
|
2856
|
+
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
|
2857
|
+
// only for native performance?
|
|
2858
|
+
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
|
2859
|
+
"disable_polyfills_on_integer_div_and_mod" };
|
|
2860
|
+
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
|
2861
|
+
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
|
2862
|
+
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
|
2863
|
+
deviceTogglesDesc.enabledToggleCount = 4;
|
|
2864
|
+
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
|
2865
|
+
deviceTogglesDesc.disabledToggleCount = 1;
|
|
2866
|
+
|
|
2867
|
+
dev_desc.nextInChain = &deviceTogglesDesc;
|
|
2868
|
+
#endif
|
|
2869
|
+
|
|
2870
|
+
ctx->webgpu_global_ctx->instance.WaitAny(
|
|
2871
|
+
ctx->webgpu_global_ctx->adapter.RequestDevice(
|
|
2872
|
+
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
|
2873
|
+
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
|
2874
|
+
if (status != wgpu::RequestDeviceStatus::Success) {
|
|
2875
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
|
|
2876
|
+
return;
|
|
2877
|
+
}
|
|
2878
|
+
ctx->webgpu_global_ctx->device = std::move(device);
|
|
2879
|
+
}),
|
|
2880
|
+
UINT64_MAX);
|
|
2881
|
+
GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
|
|
2882
|
+
|
|
2883
|
+
ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
|
|
2884
|
+
ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
|
2885
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
|
2886
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
|
2887
|
+
ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
|
|
2888
|
+
|
|
2889
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
2890
|
+
// Initialize buffer pool for timestamp queries, used for profiling
|
|
2891
|
+
ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
|
|
2892
|
+
ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
|
|
2893
|
+
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
|
|
2894
|
+
wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
|
|
2895
|
+
#endif
|
|
2896
|
+
|
|
2897
|
+
GGML_LOG_INFO(
|
|
2898
|
+
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
|
2899
|
+
"device_desc: %s\n",
|
|
2900
|
+
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
|
|
2901
|
+
std::string(info.device).c_str(), std::string(info.description).c_str());
|
|
2902
|
+
return true;
|
|
1207
2903
|
}
|
|
1208
2904
|
|
|
1209
|
-
static
|
|
2905
|
+
static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|
2906
|
+
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
|
2907
|
+
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
|
2908
|
+
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
|
2909
|
+
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
|
2910
|
+
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
|
2911
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
|
2912
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
|
|
2913
|
+
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
|
|
2914
|
+
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
|
2915
|
+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
|
|
2916
|
+
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
|
|
2917
|
+
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
|
2918
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
|
2919
|
+
|
|
2920
|
+
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
|
2921
|
+
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
|
2922
|
+
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
|
|
2923
|
+
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
|
|
2924
|
+
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
|
|
2925
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
2926
|
+
// Initialize debug buffers
|
|
2927
|
+
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
|
|
2928
|
+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
2929
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
|
|
2930
|
+
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
|
|
2931
|
+
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
2932
|
+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
|
|
2933
|
+
#endif
|
|
2934
|
+
return webgpu_ctx;
|
|
2935
|
+
}
|
|
2936
|
+
|
|
2937
|
+
static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
|
|
1210
2938
|
GGML_UNUSED(params);
|
|
1211
2939
|
|
|
1212
|
-
WEBGPU_LOG_DEBUG("
|
|
2940
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
|
|
1213
2941
|
|
|
1214
|
-
ggml_backend_webgpu_device_context * dev_ctx
|
|
1215
|
-
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
|
2942
|
+
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
1216
2943
|
|
|
1217
|
-
|
|
1218
|
-
backend_ctx
|
|
1219
|
-
backend_ctx
|
|
2944
|
+
auto * backend_ctx = new ggml_backend_webgpu_context();
|
|
2945
|
+
backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
|
2946
|
+
backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
|
|
1220
2947
|
|
|
1221
2948
|
// See GGML Backend Interface section
|
|
1222
|
-
|
|
2949
|
+
auto * backend = new ggml_backend();
|
|
2950
|
+
*backend = {
|
|
1223
2951
|
/* .guid = */ ggml_backend_webgpu_guid(),
|
|
1224
2952
|
/* .interface = */ ggml_backend_webgpu_i,
|
|
1225
2953
|
/* .device = */ dev,
|
|
1226
|
-
/* .context = */
|
|
2954
|
+
/* .context = */ backend_ctx,
|
|
1227
2955
|
};
|
|
1228
|
-
|
|
1229
|
-
return &backend;
|
|
2956
|
+
return backend;
|
|
1230
2957
|
}
|
|
1231
2958
|
|
|
1232
2959
|
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
|
|
1233
2960
|
// See GGML Backend Buffer Type Interface section
|
|
2961
|
+
|
|
1234
2962
|
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
|
|
1235
2963
|
/* .iface = */ {
|
|
1236
2964
|
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
|
1237
|
-
/* .alloc_buffer = */
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
/* .is_host = */ NULL, // defaults to false
|
|
2965
|
+
/* .alloc_buffer = */
|
|
2966
|
+
ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */
|
|
2967
|
+
ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */
|
|
2968
|
+
ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */
|
|
2969
|
+
ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false
|
|
1242
2970
|
},
|
|
1243
2971
|
/* .device = */
|
|
1244
2972
|
dev,
|
|
1245
|
-
/* .context = */
|
|
2973
|
+
/* .context = */
|
|
2974
|
+
NULL
|
|
1246
2975
|
};
|
|
1247
2976
|
|
|
1248
2977
|
return &ggml_backend_webgpu_buffer_type;
|
|
@@ -1283,14 +3012,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
|
|
1283
3012
|
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
1284
3013
|
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
1285
3014
|
|
|
1286
|
-
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
|
|
1287
|
-
|
|
1288
3015
|
ggml_tensor * src0 = op->src[0];
|
|
1289
3016
|
ggml_tensor * src1 = op->src[1];
|
|
3017
|
+
ggml_tensor * src2 = op->src[2];
|
|
3018
|
+
|
|
1290
3019
|
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
|
|
1291
|
-
if (ggml_nbytes(op) >
|
|
1292
|
-
(src0 != nullptr &&
|
|
1293
|
-
|
|
3020
|
+
if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
|
|
3021
|
+
(src0 != nullptr &&
|
|
3022
|
+
ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
|
|
3023
|
+
(src1 != nullptr &&
|
|
3024
|
+
ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
|
|
1294
3025
|
return false;
|
|
1295
3026
|
}
|
|
1296
3027
|
|
|
@@ -1304,28 +3035,43 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
1304
3035
|
supports_op = true;
|
|
1305
3036
|
break;
|
|
1306
3037
|
case GGML_OP_ADD:
|
|
3038
|
+
case GGML_OP_SUB:
|
|
1307
3039
|
case GGML_OP_MUL:
|
|
1308
|
-
|
|
1309
|
-
|
|
3040
|
+
case GGML_OP_DIV:
|
|
3041
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
|
3042
|
+
(src1->type == op->type);
|
|
3043
|
+
break;
|
|
3044
|
+
case GGML_OP_CONCAT:
|
|
3045
|
+
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
|
|
3046
|
+
break;
|
|
3047
|
+
case GGML_OP_REPEAT:
|
|
3048
|
+
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
|
|
1310
3049
|
break;
|
|
1311
3050
|
case GGML_OP_CPY:
|
|
3051
|
+
case GGML_OP_CONT:
|
|
3052
|
+
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
3053
|
+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
|
|
3054
|
+
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
|
|
3055
|
+
break;
|
|
1312
3056
|
case GGML_OP_SET_ROWS:
|
|
1313
|
-
supports_op = (op->type == GGML_TYPE_F16
|
|
3057
|
+
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
|
|
3058
|
+
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
|
1314
3059
|
break;
|
|
1315
3060
|
case GGML_OP_GET_ROWS:
|
|
1316
|
-
if (
|
|
1317
|
-
op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
|
|
3061
|
+
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
|
|
1318
3062
|
supports_op = (op->type == GGML_TYPE_F32);
|
|
3063
|
+
} else if (src0->type == GGML_TYPE_I32) {
|
|
3064
|
+
supports_op = op->type == GGML_TYPE_I32;
|
|
1319
3065
|
}
|
|
1320
3066
|
break;
|
|
1321
3067
|
case GGML_OP_MUL_MAT:
|
|
1322
3068
|
{
|
|
1323
|
-
switch (
|
|
3069
|
+
switch (src1->type) {
|
|
1324
3070
|
case GGML_TYPE_F16:
|
|
1325
|
-
supports_op
|
|
3071
|
+
supports_op |= (src0->type == GGML_TYPE_F16);
|
|
1326
3072
|
break;
|
|
1327
3073
|
case GGML_TYPE_F32:
|
|
1328
|
-
switch (
|
|
3074
|
+
switch (src0->type) {
|
|
1329
3075
|
case GGML_TYPE_F32:
|
|
1330
3076
|
case GGML_TYPE_F16:
|
|
1331
3077
|
case GGML_TYPE_Q4_0:
|
|
@@ -1357,19 +3103,160 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
1357
3103
|
}
|
|
1358
3104
|
break;
|
|
1359
3105
|
}
|
|
3106
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
3107
|
+
{
|
|
3108
|
+
#ifndef __EMSCRIPTEN__
|
|
3109
|
+
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
|
3110
|
+
break;
|
|
3111
|
+
}
|
|
3112
|
+
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
|
3113
|
+
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
|
3114
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
3115
|
+
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
|
|
3116
|
+
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
|
|
3117
|
+
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
|
3118
|
+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
|
3119
|
+
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
|
|
3120
|
+
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
|
|
3121
|
+
if (min_bytes > limit_bytes) {
|
|
3122
|
+
break;
|
|
3123
|
+
}
|
|
3124
|
+
|
|
3125
|
+
supports_op = src0->type == GGML_TYPE_F32 &&
|
|
3126
|
+
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
|
3127
|
+
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
|
3128
|
+
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
|
3129
|
+
#endif
|
|
3130
|
+
break;
|
|
3131
|
+
}
|
|
1360
3132
|
case GGML_OP_RMS_NORM:
|
|
1361
|
-
supports_op = op->type == GGML_TYPE_F32 &&
|
|
3133
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
|
3134
|
+
break;
|
|
3135
|
+
case GGML_OP_ROPE:
|
|
3136
|
+
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
|
3137
|
+
break;
|
|
3138
|
+
case GGML_OP_GLU:
|
|
3139
|
+
switch (ggml_get_glu_op(op)) {
|
|
3140
|
+
case GGML_GLU_OP_REGLU:
|
|
3141
|
+
case GGML_GLU_OP_GEGLU:
|
|
3142
|
+
case GGML_GLU_OP_SWIGLU:
|
|
3143
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
3144
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
3145
|
+
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
|
3146
|
+
break;
|
|
3147
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
3148
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
3149
|
+
break;
|
|
3150
|
+
default:
|
|
3151
|
+
break;
|
|
3152
|
+
}
|
|
3153
|
+
break;
|
|
3154
|
+
case GGML_OP_SCALE:
|
|
3155
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
3156
|
+
break;
|
|
3157
|
+
case GGML_OP_SOFT_MAX:
|
|
3158
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
3159
|
+
break;
|
|
3160
|
+
case GGML_OP_UNARY:
|
|
3161
|
+
{
|
|
3162
|
+
const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
|
|
3163
|
+
|
|
3164
|
+
switch (UNARY_OP) {
|
|
3165
|
+
case GGML_UNARY_OP_ABS:
|
|
3166
|
+
case GGML_UNARY_OP_SGN:
|
|
3167
|
+
case GGML_UNARY_OP_NEG:
|
|
3168
|
+
case GGML_UNARY_OP_STEP:
|
|
3169
|
+
case GGML_UNARY_OP_TANH:
|
|
3170
|
+
case GGML_UNARY_OP_ELU:
|
|
3171
|
+
case GGML_UNARY_OP_RELU:
|
|
3172
|
+
case GGML_UNARY_OP_SIGMOID:
|
|
3173
|
+
case GGML_UNARY_OP_GELU:
|
|
3174
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
|
3175
|
+
case GGML_UNARY_OP_SILU:
|
|
3176
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
3177
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
3178
|
+
case GGML_UNARY_OP_EXP:
|
|
3179
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
3180
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
3181
|
+
case GGML_UNARY_OP_EXPM1:
|
|
3182
|
+
case GGML_UNARY_OP_FLOOR:
|
|
3183
|
+
case GGML_UNARY_OP_CEIL:
|
|
3184
|
+
case GGML_UNARY_OP_ROUND:
|
|
3185
|
+
case GGML_UNARY_OP_TRUNC:
|
|
3186
|
+
case GGML_UNARY_OP_XIELU:
|
|
3187
|
+
supports_op =
|
|
3188
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3189
|
+
break;
|
|
3190
|
+
default:
|
|
3191
|
+
break;
|
|
3192
|
+
}
|
|
3193
|
+
}
|
|
3194
|
+
break;
|
|
3195
|
+
case GGML_OP_CLAMP:
|
|
3196
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3197
|
+
break;
|
|
3198
|
+
case GGML_OP_FILL:
|
|
3199
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
|
3200
|
+
break;
|
|
3201
|
+
case GGML_OP_LOG:
|
|
3202
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3203
|
+
break;
|
|
3204
|
+
case GGML_OP_SQR:
|
|
3205
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3206
|
+
break;
|
|
3207
|
+
case GGML_OP_SQRT:
|
|
3208
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3209
|
+
break;
|
|
3210
|
+
case GGML_OP_SIN:
|
|
3211
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3212
|
+
break;
|
|
3213
|
+
case GGML_OP_COS:
|
|
3214
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
3215
|
+
break;
|
|
3216
|
+
case GGML_OP_PAD:
|
|
3217
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
|
3218
|
+
break;
|
|
3219
|
+
case GGML_OP_ARGMAX:
|
|
3220
|
+
supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
|
|
3221
|
+
break;
|
|
3222
|
+
case GGML_OP_ARGSORT:
|
|
3223
|
+
supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
|
|
3224
|
+
break;
|
|
3225
|
+
case GGML_OP_TOP_K:
|
|
3226
|
+
supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
|
|
3227
|
+
break;
|
|
3228
|
+
case GGML_OP_CUMSUM:
|
|
3229
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
|
|
3230
|
+
break;
|
|
3231
|
+
case GGML_OP_SUM:
|
|
3232
|
+
case GGML_OP_SUM_ROWS:
|
|
3233
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
|
|
1362
3234
|
break;
|
|
1363
3235
|
default:
|
|
1364
3236
|
break;
|
|
1365
3237
|
}
|
|
1366
|
-
|
|
3238
|
+
if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
|
|
3239
|
+
(src0 != nullptr &&
|
|
3240
|
+
ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
|
|
3241
|
+
(src1 != nullptr &&
|
|
3242
|
+
ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
|
|
3243
|
+
(src2 != nullptr &&
|
|
3244
|
+
ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
|
|
3245
|
+
supports_op = false;
|
|
3246
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
|
|
3247
|
+
}
|
|
3248
|
+
|
|
1367
3249
|
if (!supports_op) {
|
|
1368
|
-
WEBGPU_LOG_DEBUG("
|
|
1369
|
-
|
|
1370
|
-
|
|
3250
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
|
|
3251
|
+
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
|
|
3252
|
+
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
|
3253
|
+
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
|
3254
|
+
} else {
|
|
3255
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
|
|
3256
|
+
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
|
|
3257
|
+
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
|
3258
|
+
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
|
1371
3259
|
}
|
|
1372
|
-
#endif
|
|
1373
3260
|
return supports_op;
|
|
1374
3261
|
}
|
|
1375
3262
|
|
|
@@ -1379,7 +3266,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
|
|
|
1379
3266
|
/* .get_memory = */ ggml_backend_webgpu_device_get_memory,
|
|
1380
3267
|
/* .get_type = */ ggml_backend_webgpu_device_get_type,
|
|
1381
3268
|
/* .get_props = */ ggml_backend_webgpu_device_get_props,
|
|
1382
|
-
/* .init_backend = */
|
|
3269
|
+
/* .init_backend = */ ggml_backend_webgpu_backend_init,
|
|
1383
3270
|
/* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
|
|
1384
3271
|
/* .get_host_buffer_type = */ NULL,
|
|
1385
3272
|
/* .buffer_from_host_ptr = */ NULL,
|
|
@@ -1405,113 +3292,29 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
|
|
|
1405
3292
|
return ctx->device_count;
|
|
1406
3293
|
}
|
|
1407
3294
|
|
|
1408
|
-
// TODO: Does this need to be thread safe? Is it only called once?
|
|
1409
3295
|
// Only one device is supported for now
|
|
1410
3296
|
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
|
1411
3297
|
GGML_ASSERT(index == 0);
|
|
1412
3298
|
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
|
|
1413
3299
|
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
|
1417
|
-
|
|
1418
|
-
wgpu::RequestAdapterOptions options = {};
|
|
1419
|
-
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
|
1420
|
-
&options, wgpu::CallbackMode::AllowSpontaneous,
|
|
1421
|
-
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
|
1422
|
-
if (status != wgpu::RequestAdapterStatus::Success) {
|
|
1423
|
-
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
|
1424
|
-
return;
|
|
1425
|
-
}
|
|
1426
|
-
ctx->adapter = std::move(adapter);
|
|
1427
|
-
}),
|
|
1428
|
-
UINT64_MAX);
|
|
1429
|
-
GGML_ASSERT(ctx->adapter != nullptr);
|
|
1430
|
-
|
|
1431
|
-
ctx->adapter.GetLimits(&ctx->limits);
|
|
1432
|
-
ctx->max_wg_size_x = 288; // default value
|
|
1433
|
-
|
|
1434
|
-
wgpu::AdapterInfo info{};
|
|
1435
|
-
ctx->adapter.GetInfo(&info);
|
|
3300
|
+
WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
|
|
1436
3301
|
|
|
1437
|
-
|
|
1438
|
-
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
|
|
1439
|
-
wgpu::FeatureName::ImplicitDeviceSynchronization };
|
|
1440
|
-
wgpu::DeviceDescriptor dev_desc;
|
|
1441
|
-
dev_desc.requiredLimits = &ctx->limits;
|
|
1442
|
-
dev_desc.requiredFeatures = required_features.data();
|
|
1443
|
-
dev_desc.requiredFeatureCount = required_features.size();
|
|
1444
|
-
dev_desc.SetDeviceLostCallback(
|
|
1445
|
-
wgpu::CallbackMode::AllowSpontaneous,
|
|
1446
|
-
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
|
1447
|
-
GGML_UNUSED(device);
|
|
1448
|
-
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
1449
|
-
std::string(message).c_str());
|
|
1450
|
-
});
|
|
1451
|
-
dev_desc.SetUncapturedErrorCallback(
|
|
1452
|
-
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
|
1453
|
-
GGML_UNUSED(device);
|
|
1454
|
-
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
1455
|
-
std::string(message).c_str());
|
|
1456
|
-
});
|
|
1457
|
-
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
|
1458
|
-
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
|
1459
|
-
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
|
1460
|
-
if (status != wgpu::RequestDeviceStatus::Success) {
|
|
1461
|
-
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
|
|
1462
|
-
std::string(message).c_str());
|
|
1463
|
-
return;
|
|
1464
|
-
}
|
|
1465
|
-
ctx->device = std::move(device);
|
|
1466
|
-
}),
|
|
1467
|
-
UINT64_MAX);
|
|
1468
|
-
GGML_ASSERT(ctx->device != nullptr);
|
|
1469
|
-
|
|
1470
|
-
// Initialize (compute) queue
|
|
1471
|
-
ctx->queue = ctx->device.GetQueue();
|
|
1472
|
-
|
|
1473
|
-
// Create buffer pool for shader parameters
|
|
1474
|
-
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
|
1475
|
-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
|
1476
|
-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
|
1477
|
-
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
|
1478
|
-
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
|
1479
|
-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
|
1480
|
-
|
|
1481
|
-
ggml_webgpu_init_memset_pipeline(ctx);
|
|
1482
|
-
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
|
1483
|
-
ggml_webgpu_init_set_rows_pipeline(ctx);
|
|
1484
|
-
ggml_webgpu_init_get_rows_pipeline(ctx);
|
|
1485
|
-
ggml_webgpu_init_cpy_pipeline(ctx);
|
|
1486
|
-
ggml_webgpu_init_add_pipeline(ctx);
|
|
1487
|
-
ggml_webgpu_init_mul_pipeline(ctx);
|
|
1488
|
-
ggml_webgpu_init_rms_norm_pipeline(ctx);
|
|
3302
|
+
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
|
1489
3303
|
|
|
1490
|
-
|
|
1491
|
-
// Initialize debug buffers
|
|
1492
|
-
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
1493
|
-
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
|
|
1494
|
-
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
1495
|
-
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
|
|
1496
|
-
#endif
|
|
3304
|
+
create_webgpu_device(reg_ctx);
|
|
1497
3305
|
|
|
1498
3306
|
static ggml_backend_webgpu_device_context device_ctx;
|
|
1499
|
-
device_ctx.
|
|
1500
|
-
device_ctx.
|
|
1501
|
-
device_ctx.
|
|
1502
|
-
|
|
1503
|
-
GGML_LOG_INFO(
|
|
1504
|
-
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
|
1505
|
-
"device_desc: %s\n",
|
|
1506
|
-
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
|
|
1507
|
-
std::string(info.device).c_str(), std::string(info.description).c_str());
|
|
1508
|
-
|
|
3307
|
+
device_ctx.device_name = GGML_WEBGPU_NAME;
|
|
3308
|
+
device_ctx.device_desc = GGML_WEBGPU_NAME;
|
|
3309
|
+
device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
|
|
1509
3310
|
// See GGML Backend Device Interface section
|
|
1510
3311
|
static ggml_backend_device device = {
|
|
1511
3312
|
/* .iface = */ ggml_backend_webgpu_device_i,
|
|
1512
3313
|
/* .reg = */ reg,
|
|
1513
3314
|
/* .context = */ &device_ctx,
|
|
1514
3315
|
};
|
|
3316
|
+
|
|
3317
|
+
WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
|
|
1515
3318
|
return &device;
|
|
1516
3319
|
}
|
|
1517
3320
|
|
|
@@ -1527,10 +3330,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
|
|
|
1527
3330
|
ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|
1528
3331
|
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
|
1529
3332
|
|
|
1530
|
-
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
|
1531
|
-
|
|
1532
3333
|
static ggml_backend_webgpu_reg_context ctx;
|
|
1533
|
-
ctx.webgpu_ctx = webgpu_ctx;
|
|
1534
3334
|
ctx.name = GGML_WEBGPU_NAME;
|
|
1535
3335
|
ctx.device_count = 1;
|
|
1536
3336
|
|
|
@@ -1538,8 +3338,26 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|
|
1538
3338
|
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
|
1539
3339
|
instance_descriptor.requiredFeatures = instance_features.data();
|
|
1540
3340
|
instance_descriptor.requiredFeatureCount = instance_features.size();
|
|
1541
|
-
|
|
1542
|
-
|
|
3341
|
+
|
|
3342
|
+
#ifndef __EMSCRIPTEN__
|
|
3343
|
+
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
|
|
3344
|
+
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
|
|
3345
|
+
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
|
|
3346
|
+
instanceTogglesDesc.enabledToggleCount = 1;
|
|
3347
|
+
instance_descriptor.nextInChain = &instanceTogglesDesc;
|
|
3348
|
+
#endif
|
|
3349
|
+
|
|
3350
|
+
wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
|
|
3351
|
+
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
|
3352
|
+
ctx.webgpu_global_ctx->instance = std::move(inst);
|
|
3353
|
+
|
|
3354
|
+
#ifdef __EMSCRIPTEN__
|
|
3355
|
+
if (ctx.webgpu_global_ctx->instance == nullptr) {
|
|
3356
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
|
3357
|
+
return nullptr;
|
|
3358
|
+
}
|
|
3359
|
+
#endif
|
|
3360
|
+
GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
|
|
1543
3361
|
|
|
1544
3362
|
static ggml_backend_reg reg = {
|
|
1545
3363
|
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
@@ -1552,7 +3370,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|
|
1552
3370
|
ggml_backend_t ggml_backend_webgpu_init(void) {
|
|
1553
3371
|
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
|
|
1554
3372
|
|
|
1555
|
-
return
|
|
3373
|
+
return ggml_backend_webgpu_backend_init(dev, nullptr);
|
|
1556
3374
|
}
|
|
1557
3375
|
|
|
1558
3376
|
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
|