whispercpp 1.3.3 → 1.3.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +60 -43
- data/ext/extconf.rb +2 -2
- data/ext/ruby_whisper.c +14 -2
- data/ext/ruby_whisper.h +39 -0
- data/ext/ruby_whisper_context.c +22 -22
- data/ext/ruby_whisper_model.c +12 -12
- data/ext/ruby_whisper_params.c +79 -25
- data/ext/ruby_whisper_segment.c +84 -19
- data/ext/ruby_whisper_token.c +351 -0
- data/ext/ruby_whisper_transcribe.cpp +1 -1
- data/ext/ruby_whisper_vad_context.c +75 -0
- data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
- data/ext/ruby_whisper_vad_segment.c +139 -0
- data/ext/ruby_whisper_vad_segments.c +106 -0
- data/ext/sources/CMakeLists.txt +4 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
- data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
- data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
- data/ext/sources/examples/CMakeLists.txt +1 -0
- data/ext/sources/examples/addon.node/addon.cpp +19 -19
- data/ext/sources/examples/addon.node/index.js +7 -5
- data/ext/sources/examples/addon.node/vad-example.js +2 -2
- data/ext/sources/examples/bench/bench.cpp +26 -16
- data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
- data/ext/sources/examples/cli/cli.cpp +122 -111
- data/ext/sources/examples/command/command.cpp +26 -24
- data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
- data/ext/sources/examples/lsp/lsp.cpp +19 -17
- data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
- data/ext/sources/examples/server/server.cpp +34 -24
- data/ext/sources/examples/server.py +6 -1
- data/ext/sources/examples/stream/stream.cpp +4 -2
- data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
- data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
- data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
- data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
- data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
- data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
- data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
- data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
- data/ext/sources/examples/talk-llama/llama-context.h +99 -36
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
- data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
- data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
- data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
- data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
- data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
- data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
- data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
- data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
- data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
- data/ext/sources/examples/talk-llama/llama-model.h +104 -12
- data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
- data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
- data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
- data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
- data/ext/sources/examples/talk-llama/llama.cpp +794 -12
- data/ext/sources/examples/talk-llama/llama.h +246 -190
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
- data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
- data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
- data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
- data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
- data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
- data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
- data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
- data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
- data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
- data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
- data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
- data/ext/sources/examples/talk-llama/models/models.h +569 -0
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
- data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
- data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
- data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
- data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
- data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
- data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
- data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
- data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
- data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
- data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
- data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
- data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
- data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
- data/ext/sources/examples/talk-llama/unicode.h +45 -0
- data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
- data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
- data/ext/sources/ggml/CMakeLists.txt +135 -79
- data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
- data/ext/sources/ggml/include/ggml-alloc.h +9 -0
- data/ext/sources/ggml/include/ggml-backend.h +21 -2
- data/ext/sources/ggml/include/ggml-cpu.h +2 -1
- data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
- data/ext/sources/ggml/include/ggml-metal.h +1 -6
- data/ext/sources/ggml/include/ggml-opt.h +25 -6
- data/ext/sources/ggml/include/ggml-rpc.h +8 -11
- data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
- data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
- data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
- data/ext/sources/ggml/include/ggml.h +406 -23
- data/ext/sources/ggml/src/CMakeLists.txt +99 -13
- data/ext/sources/ggml/src/ggml-alloc.c +368 -161
- data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
- data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
- data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
- data/ext/sources/ggml/src/ggml-common.h +17 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
- data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
- data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
- data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
- data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
- data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
- data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
- data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
- data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
- data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
- data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
- data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
- data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
- data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
- data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
- data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
- data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
- data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
- data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
- data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
- data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
- data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
- data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
- data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
- data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
- data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
- data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
- data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
- data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
- data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
- data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
- data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
- data/ext/sources/ggml/src/ggml-impl.h +186 -15
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
- data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
- data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
- data/ext/sources/ggml/src/ggml-quants.c +111 -16
- data/ext/sources/ggml/src/ggml-quants.h +6 -0
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
- data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
- data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
- data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
- data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
- data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
- data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
- data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
- data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
- data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
- data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
- data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
- data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
- data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
- data/ext/sources/ggml/src/ggml.c +901 -129
- data/ext/sources/ggml/src/gguf.cpp +8 -1
- data/ext/sources/include/whisper.h +1 -0
- data/ext/sources/src/CMakeLists.txt +3 -1
- data/ext/sources/src/whisper.cpp +124 -81
- data/ext/sources/tests/CMakeLists.txt +8 -1
- data/ext/sources/tests/test-vad-full.cpp +7 -5
- data/ext/sources/tests/test-vad.cpp +3 -3
- data/extsources.rb +1 -0
- data/lib/whisper/model/uri.rb +17 -18
- data/sig/whisper.rbs +126 -2
- data/test/test_params.rb +24 -8
- data/test/test_segment.rb +0 -1
- data/test/test_token.rb +70 -0
- data/test/test_vad.rb +1 -1
- data/test/test_vad_context.rb +50 -0
- data/test/test_vad_segment.rb +19 -0
- data/test/test_vad_segments.rb +16 -0
- data/test/test_whisper.rb +8 -1
- data/whispercpp.gemspec +1 -1
- metadata +439 -179
- data/ext/sources/build-xcframework.sh +0 -547
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
- data/ext/sources/ggml/include/ggml-kompute.h +0 -50
- data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
- data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
- data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
- data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
- data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
- data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
|
@@ -0,0 +1,3087 @@
|
|
|
1
|
+
/*
|
|
2
|
+
WebGPU backend implementation.
|
|
3
|
+
Note: Use ClangFormat to format this file.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
#include "ggml-webgpu.h"
|
|
7
|
+
|
|
8
|
+
#include "ggml-backend-impl.h"
|
|
9
|
+
#include "ggml-impl.h"
|
|
10
|
+
#include "ggml-webgpu-shader-lib.hpp"
|
|
11
|
+
#include "ggml-wgsl-shaders.hpp"
|
|
12
|
+
#include "pre_wgsl.hpp"
|
|
13
|
+
|
|
14
|
+
#ifdef __EMSCRIPTEN__
|
|
15
|
+
# include <emscripten/emscripten.h>
|
|
16
|
+
#endif
|
|
17
|
+
|
|
18
|
+
#include <webgpu/webgpu_cpp.h>
|
|
19
|
+
|
|
20
|
+
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
30
|
+
|
|
31
|
+
// Rounds x up to the next multiple of pow2; pow2 must be a power of two.
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
// Divides M by N, rounding the quotient up (intended for non-negative operands).
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

#ifdef GGML_WEBGPU_DEBUG
// Streams msg to std::cout followed by a newline (requires <iostream>).
#    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
#    define WEBGPU_DEBUG_BUF_ELEMS 512
#else
#    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif  // GGML_WEBGPU_DEBUG

#ifdef GGML_WEBGPU_CPU_PROFILE
// The timing macros below use std::chrono; include it explicitly so the profiling build does not
// rely on <chrono> arriving transitively through some other header.
#    include <chrono>

// total timing (aggregated)
// TOTAL_START(id) declares a local start timestamp; TOTAL_END(id, ctx) must follow in the same scope and
// adds the elapsed milliseconds to (ctx)->cpu_time_ms[#id].
#    define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)                                                          \
        auto   cpu_total_end_##id = std::chrono::high_resolution_clock::now();                             \
        double cpu_total_time_##id =                                                                       \
            std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
        (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;

// fine-grained timing (not included in totals)
// DETAIL_START/DETAIL_END pair the same way, accumulating into (ctx)->cpu_detail_ms[#id].
#    define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)                                                           \
        auto   cpu_detail_end_##id = std::chrono::high_resolution_clock::now();                              \
        double cpu_detail_time_##id =                                                                        \
            std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
        (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
#else
#    define WEBGPU_CPU_PROFILE_TOTAL_START(id)
#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
#    define WEBGPU_CPU_PROFILE_DETAIL_START(id)
#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
#endif  // GGML_WEBGPU_CPU_PROFILE

#ifdef GGML_WEBGPU_GPU_PROFILE
#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
#    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // two 8-byte (uint64) timestamps
#endif
|
|
70
|
+
|
|
71
|
+
/* Constants */

// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
#define WEBGPU_MAX_WG_SIZE 288

#define WEBGPU_MUL_MAT_WG_SIZE           256
#define WEBGPU_NUM_PARAM_BUFS            32u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS       0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool.
// Parenthesized so the quotient is formed before any operator surrounding the macro is applied
// (e.g. `x % WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD` or `x / WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD`).
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 uint32_t parameters
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4

// For operations which process a row in parallel, this seems like a reasonable default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64

// Matrix multiplication parameters

// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M    8
#define WEBGPU_MUL_MAT_TILE_N    8
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_TILE_K    32

// Subgroup matrix parameters
// The number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_M        2
// The number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N        2
// The number of subgroup matrices each subgroup accumulates over
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_TILE_K         256

// Enforce the documented mul_mat_vec invariants at compile time.
static_assert(WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG % 4 == 0,
              "WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG must be a multiple of 4 for the vectorized paths");
static_assert(WEBGPU_MUL_MAT_VEC_WG_SIZE % WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG == 0,
              "WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG must divide WEBGPU_MUL_MAT_VEC_WG_SIZE");

/* End Constants */
|
|
115
|
+
|
|
116
|
+
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
|
|
117
|
+
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
|
|
118
|
+
|
|
119
|
+
// Always returns the base offset of a tensor, regardless of views.
|
|
120
|
+
static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
121
|
+
if (tensor->view_src) {
|
|
122
|
+
return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
|
|
123
|
+
}
|
|
124
|
+
return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
/* Struct definitions */
|
|
128
|
+
|
|
129
|
+
// Forward reference
|
|
130
|
+
static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|
131
|
+
wgpu::Buffer & buffer,
|
|
132
|
+
size_t size,
|
|
133
|
+
wgpu::BufferUsage usage,
|
|
134
|
+
const char * label);
|
|
135
|
+
|
|
136
|
+
struct webgpu_pool_bufs {
|
|
137
|
+
wgpu::Buffer host_buf;
|
|
138
|
+
wgpu::Buffer dev_buf;
|
|
139
|
+
};
|
|
140
|
+
|
|
141
|
+
// The futures to wait on for a single queue submission
|
|
142
|
+
struct webgpu_submission_futures {
|
|
143
|
+
std::vector<wgpu::FutureWaitInfo> futures;
|
|
144
|
+
};
|
|
145
|
+
|
|
146
|
+
// Holds a pool of parameter buffers for WebGPU operations
|
|
147
|
+
struct webgpu_buf_pool {
|
|
148
|
+
std::vector<webgpu_pool_bufs> free;
|
|
149
|
+
|
|
150
|
+
std::mutex mutex;
|
|
151
|
+
|
|
152
|
+
std::condition_variable cv;
|
|
153
|
+
|
|
154
|
+
void init(wgpu::Device device,
|
|
155
|
+
int num_bufs,
|
|
156
|
+
size_t buf_size,
|
|
157
|
+
wgpu::BufferUsage dev_buf_usage,
|
|
158
|
+
wgpu::BufferUsage host_buf_usage) {
|
|
159
|
+
for (int i = 0; i < num_bufs; i++) {
|
|
160
|
+
wgpu::Buffer host_buf;
|
|
161
|
+
wgpu::Buffer dev_buf;
|
|
162
|
+
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
|
163
|
+
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
|
164
|
+
free.push_back({ host_buf, dev_buf });
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
webgpu_pool_bufs alloc_bufs() {
|
|
169
|
+
std::unique_lock<std::mutex> lock(mutex);
|
|
170
|
+
cv.wait(lock, [this] { return !free.empty(); });
|
|
171
|
+
webgpu_pool_bufs bufs = free.back();
|
|
172
|
+
free.pop_back();
|
|
173
|
+
return bufs;
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
|
|
177
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
178
|
+
free.insert(free.end(), bufs.begin(), bufs.end());
|
|
179
|
+
cv.notify_all();
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
void cleanup() {
|
|
183
|
+
std::lock_guard<std::mutex> lock(mutex);
|
|
184
|
+
for (auto & bufs : free) {
|
|
185
|
+
bufs.host_buf.Destroy();
|
|
186
|
+
bufs.dev_buf.Destroy();
|
|
187
|
+
}
|
|
188
|
+
free.clear();
|
|
189
|
+
}
|
|
190
|
+
};
|
|
191
|
+
|
|
192
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
// Resources used to time one GPU operation: a CPU-mappable buffer, its device-side counterpart, and a
// two-entry timestamp query set.
// Member order is significant: instances are brace-initialized as { host_buf, dev_buf, query_set }.
struct webgpu_gpu_profile_bufs {
    wgpu::Buffer   host_buf;
    wgpu::Buffer   dev_buf;
    wgpu::QuerySet query_set;
};

// Thread-safe pool of timestamp-query resources used for GPU profiling (one set per timed operation).
// alloc_bufs() blocks until a set is available; free_bufs() returns sets and wakes any waiters.
struct webgpu_gpu_profile_buf_pool {
    std::vector<webgpu_gpu_profile_bufs> free;

    std::mutex mutex;

    std::condition_variable cv;

    // Creates num_bufs profiling sets: two buf_size-byte buffers plus a 2-timestamp query set each.
    void init(wgpu::Device      device,
              int               num_bufs,
              size_t            buf_size,
              wgpu::BufferUsage dev_buf_usage,
              wgpu::BufferUsage host_buf_usage) {
        for (int i = 0; i < num_bufs; i++) {
            wgpu::Buffer host_buf;
            wgpu::Buffer dev_buf;
            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
            // Create a query set for 2 timestamps (start and end of the operation)
            wgpu::QuerySetDescriptor ts_query_set_desc = {};

            ts_query_set_desc.type  = wgpu::QueryType::Timestamp;
            ts_query_set_desc.count = 2;
            wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);

            free.push_back({ host_buf, dev_buf, ts_query_set });
        }
    }

    // Takes a set from the pool, blocking while the pool is empty.
    webgpu_gpu_profile_bufs alloc_bufs() {
        std::unique_lock<std::mutex> lock(mutex);
        cv.wait(lock, [this] { return !free.empty(); });
        webgpu_gpu_profile_bufs bufs = free.back();
        free.pop_back();
        return bufs;
    }

    // Returns sets to the pool and wakes every thread waiting in alloc_bufs().
    void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
        std::lock_guard<std::mutex> lock(mutex);
        free.insert(free.end(), bufs.begin(), bufs.end());
        cv.notify_all();
    }

    // Destroys every set currently held by the pool (checked-out sets are untouched) and empties it.
    void cleanup() {
        std::lock_guard<std::mutex> lock(mutex);
        for (auto & bufs : free) {
            bufs.host_buf.Destroy();
            bufs.dev_buf.Destroy();
            bufs.query_set.Destroy();
        }
        free.clear();
    }
};
#endif
|
|
253
|
+
|
|
254
|
+
struct webgpu_pipeline {
|
|
255
|
+
wgpu::ComputePipeline pipeline;
|
|
256
|
+
std::string name;
|
|
257
|
+
void * context = nullptr;
|
|
258
|
+
};
|
|
259
|
+
|
|
260
|
+
struct webgpu_command {
|
|
261
|
+
wgpu::CommandBuffer commands;
|
|
262
|
+
webgpu_pool_bufs params_bufs;
|
|
263
|
+
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
|
|
264
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
265
|
+
webgpu_gpu_profile_bufs timestamp_query_bufs;
|
|
266
|
+
std::string pipeline_name;
|
|
267
|
+
#endif
|
|
268
|
+
};
|
|
269
|
+
|
|
270
|
+
// Everything that selects a specialized flash-attention pipeline variant.
struct flash_attn_pipeline_key {
    int      q_type;
    int      kv_type;
    int      dst_type;
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
    bool     kv_direct;
    bool     has_mask;
    bool     has_sinks;
    bool     uses_logit_softcap;

    // Two keys are equal exactly when every field matches.
    bool operator==(const flash_attn_pipeline_key & other) const {
        if (q_type != other.q_type || kv_type != other.kv_type || dst_type != other.dst_type) {
            return false;
        }
        if (head_dim_qk != other.head_dim_qk || head_dim_v != other.head_dim_v) {
            return false;
        }
        return kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
    }
};

// Boost-style hash_combine: mixes std::hash<T>(value) into seed.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t mixed = std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= mixed;
}

// Hash functor for flash_attn_pipeline_key; it folds the fields in declaration order.
struct flash_attn_pipeline_key_hash {
    size_t operator()(const flash_attn_pipeline_key & key) const {
        size_t acc = 0;
        ggml_webgpu_hash_combine(acc, key.q_type);
        ggml_webgpu_hash_combine(acc, key.kv_type);
        ggml_webgpu_hash_combine(acc, key.dst_type);
        ggml_webgpu_hash_combine(acc, key.head_dim_qk);
        ggml_webgpu_hash_combine(acc, key.head_dim_v);
        ggml_webgpu_hash_combine(acc, key.kv_direct);
        ggml_webgpu_hash_combine(acc, key.has_mask);
        ggml_webgpu_hash_combine(acc, key.has_sinks);
        ggml_webgpu_hash_combine(acc, key.uses_logit_softcap);
        return acc;
    }
};
|
|
309
|
+
|
|
310
|
+
// All the base objects needed to run operations on a WebGPU device
|
|
311
|
+
struct webgpu_context_struct {
|
|
312
|
+
wgpu::Instance instance;
|
|
313
|
+
wgpu::Adapter adapter;
|
|
314
|
+
wgpu::Device device;
|
|
315
|
+
wgpu::Queue queue;
|
|
316
|
+
wgpu::Limits limits;
|
|
317
|
+
|
|
318
|
+
uint32_t max_subgroup_size;
|
|
319
|
+
|
|
320
|
+
bool supports_subgroup_matrix = false;
|
|
321
|
+
uint32_t sg_mat_m;
|
|
322
|
+
uint32_t sg_mat_n;
|
|
323
|
+
uint32_t sg_mat_k;
|
|
324
|
+
|
|
325
|
+
std::recursive_mutex mutex;
|
|
326
|
+
std::atomic_uint inflight_threads = 0;
|
|
327
|
+
|
|
328
|
+
webgpu_buf_pool param_buf_pool;
|
|
329
|
+
webgpu_buf_pool set_rows_error_buf_pool;
|
|
330
|
+
|
|
331
|
+
pre_wgsl::Preprocessor p;
|
|
332
|
+
|
|
333
|
+
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
|
|
334
|
+
|
|
335
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
|
|
336
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
|
|
337
|
+
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
|
|
338
|
+
|
|
339
|
+
std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
|
|
340
|
+
|
|
341
|
+
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
|
|
342
|
+
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
|
343
|
+
|
|
344
|
+
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
|
345
|
+
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
|
346
|
+
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
|
347
|
+
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
|
348
|
+
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
|
349
|
+
|
|
350
|
+
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
|
351
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
|
352
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
|
|
353
|
+
std::map<int, webgpu_pipeline> scale_pipelines; // inplace
|
|
354
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
|
|
355
|
+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
|
|
356
|
+
|
|
357
|
+
size_t memset_bytes_per_thread;
|
|
358
|
+
|
|
359
|
+
// Staging buffer for reading data from the GPU
|
|
360
|
+
wgpu::Buffer get_tensor_staging_buf;
|
|
361
|
+
|
|
362
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
363
|
+
wgpu::Buffer debug_host_buf;
|
|
364
|
+
wgpu::Buffer debug_dev_buf;
|
|
365
|
+
#endif
|
|
366
|
+
|
|
367
|
+
#ifdef GGML_WEBGPU_CPU_PROFILE
|
|
368
|
+
// Profiling: labeled CPU time in ms (total)
|
|
369
|
+
std::unordered_map<std::string, double> cpu_time_ms;
|
|
370
|
+
// Profiling: detailed CPU time in ms
|
|
371
|
+
std::unordered_map<std::string, double> cpu_detail_ms;
|
|
372
|
+
#endif
|
|
373
|
+
|
|
374
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
375
|
+
// Profiling: per-shader GPU time in ms
|
|
376
|
+
std::unordered_map<std::string, double> shader_gpu_time_ms;
|
|
377
|
+
// Profiling: pool of timestamp query buffers (one per operation)
|
|
378
|
+
webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
|
|
379
|
+
#endif
|
|
380
|
+
};
|
|
381
|
+
|
|
382
|
+
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
|
|
383
|
+
|
|
384
|
+
struct ggml_backend_webgpu_reg_context {
|
|
385
|
+
webgpu_context webgpu_ctx;
|
|
386
|
+
size_t device_count;
|
|
387
|
+
const char * name;
|
|
388
|
+
};
|
|
389
|
+
|
|
390
|
+
struct ggml_backend_webgpu_device_context {
|
|
391
|
+
webgpu_context webgpu_ctx;
|
|
392
|
+
std::string device_name;
|
|
393
|
+
std::string device_desc;
|
|
394
|
+
};
|
|
395
|
+
|
|
396
|
+
struct ggml_backend_webgpu_context {
|
|
397
|
+
webgpu_context webgpu_ctx;
|
|
398
|
+
std::string name;
|
|
399
|
+
};
|
|
400
|
+
|
|
401
|
+
struct ggml_backend_webgpu_buffer_context {
|
|
402
|
+
webgpu_context webgpu_ctx;
|
|
403
|
+
wgpu::Buffer buffer;
|
|
404
|
+
std::string label;
|
|
405
|
+
|
|
406
|
+
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
|
|
407
|
+
webgpu_ctx(std::move(ctx)),
|
|
408
|
+
buffer(std::move(buf)),
|
|
409
|
+
label(std::move(lbl)) {}
|
|
410
|
+
};
|
|
411
|
+
|
|
412
|
+
/* WebGPU object initializations */
|
|
413
|
+
|
|
414
|
+
// Expands "{{KEY}}" placeholders in the WGSL source src using the values in repls.
// Keys are applied one after another, in map order, so a later key can match text that an earlier
// replacement produced. Text inserted for a key is never rescanned for that same key, which makes a
// value containing its own placeholder safe. A null src yields an empty string.
static std::string ggml_webgpu_process_shader_repls(const char *                               src,
                                                    const std::map<std::string, std::string> & repls) {
    if (src == nullptr) {
        return {};
    }
    std::string result(src);
    for (const auto & [key, value] : repls) {
        const std::string placeholder = "{{" + key + "}}";
        std::string       expanded;
        size_t            cursor = 0;
        for (size_t hit = result.find(placeholder); hit != std::string::npos; hit = result.find(placeholder, cursor)) {
            expanded.append(result, cursor, hit - cursor);
            expanded += value;
            cursor = hit + placeholder.size();
        }
        expanded.append(result, cursor, std::string::npos);
        result.swap(expanded);
    }
    return result;
}
|
|
432
|
+
|
|
433
|
+
// Compiles shader_code as a WGSL module and builds a compute pipeline from its "main" entry point
// with an automatically derived layout. Pipeline-overridable constants, if any, are passed through.
// The returned webgpu_pipeline is named after label.
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &                           device,
                                                   const char *                             shader_code,
                                                   const char *                             label,
                                                   const std::vector<wgpu::ConstantEntry> & constants = {}) {
    wgpu::ShaderSourceWGSL wgsl_source;
    wgsl_source.code = shader_code;

    wgpu::ShaderModuleDescriptor module_desc;
    module_desc.nextInChain = &wgsl_source;
    wgpu::ShaderModule shader_mod = device.CreateShaderModule(&module_desc);

    wgpu::ComputePipelineDescriptor desc;
    desc.label              = label;
    desc.layout             = nullptr;  // let WebGPU derive the bind group layout
    desc.compute.module     = shader_mod;
    desc.compute.entryPoint = "main";
    if (!constants.empty()) {
        desc.compute.constantCount = constants.size();
        desc.compute.constants     = constants.data();
    }

    webgpu_pipeline built{ device.CreateComputePipeline(&desc), label };
    return built;
}
|
|
456
|
+
|
|
457
|
+
// Creates an unmapped WebGPU buffer of size bytes with the given usage flags and debug label,
// and stores it in buffer (replacing whatever handle it held).
static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                      wgpu::Buffer &    buffer,
                                      size_t            size,
                                      wgpu::BufferUsage usage,
                                      const char *      label) {
    wgpu::BufferDescriptor desc;
    desc.label            = label;
    desc.usage            = usage;
    desc.size             = size;
    desc.mappedAtCreation = false;

    // TODO: error handling
    buffer = device.CreateBuffer(&desc);
}
|
|
471
|
+
|
|
472
|
+
/** End WebGPU object initializations */
|
|
473
|
+
|
|
474
|
+
/** WebGPU Actions */
|
|
475
|
+
|
|
476
|
+
// Wait for the queue to finish processing all submitted work
|
|
477
|
+
static void ggml_backend_webgpu_wait(webgpu_context & ctx,
|
|
478
|
+
std::vector<webgpu_submission_futures> & futures,
|
|
479
|
+
bool block = true) {
|
|
480
|
+
// If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
|
|
481
|
+
// inflight_max may be 0, meaning that we must wait on all futures.
|
|
482
|
+
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
|
483
|
+
uint32_t inflight_threads = ctx->inflight_threads;
|
|
484
|
+
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
|
|
485
|
+
while (futures.size() >= inflight_max && futures.size() > 0) {
|
|
486
|
+
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
|
|
487
|
+
futures.erase(futures.begin());
|
|
488
|
+
}
|
|
489
|
+
size_t i = 0;
|
|
490
|
+
while (i < futures.size()) {
|
|
491
|
+
auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
|
|
492
|
+
switch (waitStatus) {
|
|
493
|
+
case wgpu::WaitStatus::Success:
|
|
494
|
+
futures.erase(futures.begin() + i);
|
|
495
|
+
break;
|
|
496
|
+
case wgpu::WaitStatus::TimedOut:
|
|
497
|
+
i++;
|
|
498
|
+
break;
|
|
499
|
+
case wgpu::WaitStatus::Error:
|
|
500
|
+
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
|
501
|
+
break;
|
|
502
|
+
default:
|
|
503
|
+
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
|
504
|
+
break;
|
|
505
|
+
}
|
|
506
|
+
}
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
// Maps [offset, offset + size) of buffer with the given mode and blocks until the map request resolves.
// A failed map is logged only. Callers should not assume the range is mapped after an error.
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                                           wgpu::Buffer &   buffer,
                                           wgpu::MapMode    mode,
                                           size_t           offset,
                                           size_t           size) {
    ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
                                          [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                                              if (status != wgpu::MapAsyncStatus::Success) {
                                                  // StringView is (data, length) and need not be
                                                  // NUL-terminated, so copy it before passing to %s.
                                                  GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
                                                                 std::string(message).c_str());
                                              }
                                          }),
                          UINT64_MAX);
}
|
|
523
|
+
|
|
524
|
+
#ifdef GGML_WEBGPU_DEBUG
// Shader debugging aid, since WebGPU shaders cannot print. To use it, bind ctx->debug_dev_buf in the
// shader under inspection, write values into it from the shader, and call this after the commands have
// been submitted. It copies debug_dev_buf into debug_host_buf, maps the copy for reading, prints the
// first value as a float, and unmaps the buffer.
static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
    const uint64_t nbytes = ctx->debug_host_buf.GetSize();

    wgpu::CommandEncoder enc = ctx->device.CreateCommandEncoder();
    enc.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, nbytes);
    wgpu::CommandBuffer cmd = enc.Finish();
    ctx->queue.Submit(1, &cmd);

    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, nbytes);
    const float * values = (const float *) ctx->debug_host_buf.GetConstMappedRange();
    std::cout << "debug[0]: " << values[0] << "\n";
    ctx->debug_host_buf.Unmap();
}
#endif
|
|
539
|
+
|
|
540
|
+
// Submits the command buffers of all commands in one queue Submit call. Returns the futures that
// complete once the GPU work, and the follow-up buffer maps it triggers, have finished.
//
// Completion callbacks hand pooled buffers back:
//  - parameter buffers return to ctx->param_buf_pool when the submitted work is done;
//  - each SET_ROWS error buffer is mapped for reading, a non-zero error word aborts, and the buffer
//    returns to ctx->set_rows_error_buf_pool;
//  - with GGML_WEBGPU_GPU_PROFILE, each command's timestamp buffer is mapped, the elapsed time is added
//    to ctx->shader_gpu_time_ms[pipeline_name], and the buffer returns to ctx->timestamp_query_buf_pool.
// Successfully mapped buffers go back still mapped, because Unmap cannot be called from inside these
// callbacks. Buffers whose map request failed are returned as well. Otherwise every failed map would
// permanently shrink a pool whose alloc_bufs() blocks forever once the pool is empty.
static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
    std::vector<wgpu::CommandBuffer> command_buffers;
    std::vector<webgpu_pool_bufs>    params_bufs;
    std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
    command_buffers.reserve(commands.size());
    params_bufs.reserve(commands.size());

    for (const auto & command : commands) {
        command_buffers.push_back(command.commands);
        params_bufs.push_back(command.params_bufs);
        if (command.set_rows_error_bufs) {
            set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
        }
    }
    ctx->queue.Submit(command_buffers.size(), command_buffers.data());

    std::vector<wgpu::FutureWaitInfo> futures;

    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
            if (status != wgpu::QueueWorkDoneStatus::Success) {
                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
            }
            // Free the staged buffers
            ctx->param_buf_pool.free_bufs(params_bufs);
        });
    futures.push_back({ p_f });

    for (const auto & bufs : set_rows_error_bufs) {
        wgpu::Future f = bufs.host_buf.MapAsync(
            wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
            [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                if (status != wgpu::MapAsyncStatus::Success) {
                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
                } else {
                    const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
                    if (*error_data) {
                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
                    }
                }
                // We can't unmap in here due to WebGPU reentrancy limitations.
                // Return the buffer even if the map failed, so the pool does not drain.
                ctx->set_rows_error_buf_pool.free_bufs({ bufs });
            });
        futures.push_back({ f });
    }

#ifdef GGML_WEBGPU_GPU_PROFILE
    for (const auto & command : commands) {
        auto label   = command.pipeline_name;
        auto ts_bufs = command.timestamp_query_bufs;

        wgpu::Future f = ts_bufs.host_buf.MapAsync(
            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
            [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                if (status != wgpu::MapAsyncStatus::Success) {
                    GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
                } else {
                    const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
                    // WebGPU timestamps are in ns; convert to ms
                    double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                    ctx->shader_gpu_time_ms[label] += elapsed_ms;
                }
                // We can't unmap in here due to WebGPU reentrancy limitations.
                // Return the buffers even if the map failed, so the pool does not drain.
                ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
            });
        futures.push_back({ f });
    }
#endif
    return { futures };
}
|
|
612
|
+
|
|
613
|
+
// Records one compute dispatch of `pipeline` into a self-contained command buffer (not submitted here).
//
// - `params` is written into a pooled host staging buffer, copied into the paired device buffer at
//   the start of the command buffer, and bound at the index right after the caller-supplied
//   `bind_group_entries` (i.e. it is always the last binding of group 0).
// - `wg_x` / `wg_y` are the workgroup counts for DispatchWorkgroups; z is always 1.
// - If `set_rows_error_bufs` is provided, its device buffer is copied back into its host buffer after
//   the dispatch so it can be inspected once the submission completes.
// - With GGML_WEBGPU_GPU_PROFILE, begin/end pass timestamps are resolved into a pooled query buffer
//   and copied to its host buffer.
//
// The returned webgpu_command carries the pooled buffers it used so they can be returned to their
// pools after submission.
static webgpu_command ggml_backend_webgpu_build(webgpu_context &                  ctx,
                                                webgpu_pipeline &                 pipeline,
                                                std::vector<uint32_t>             params,
                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                uint32_t                          wg_x,
                                                uint32_t                          wg_y                = 1,
                                                std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

    // The pooled parameter buffer has a fixed size: never write past the end of its mapped range.
    GGML_ASSERT(params.size() * sizeof(uint32_t) <= params_bufs.host_buf.GetSize());

    ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
    uint32_t * mapped_params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
    for (size_t i = 0; i < params.size(); i++) {
        mapped_params[i] = params[i];
    }
    params_bufs.host_buf.Unmap();

    // The params buffer always takes the binding slot after the caller's entries.
    uint32_t params_bufs_binding_num = bind_group_entries.size();
    bind_group_entries.push_back({ .binding = params_bufs_binding_num,
                                   .buffer  = params_bufs.dev_buf,
                                   .offset  = 0,
                                   .size    = params_bufs.dev_buf.GetSize() });

    wgpu::BindGroupDescriptor bind_group_desc;
    bind_group_desc.layout     = pipeline.pipeline.GetBindGroupLayout(0);
    bind_group_desc.entryCount = bind_group_entries.size();
    bind_group_desc.entries    = bind_group_entries.data();
    bind_group_desc.label      = pipeline.name.c_str();
    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);

    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
    // Upload the staged params before the dispatch reads them.
    encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());

#ifdef GGML_WEBGPU_GPU_PROFILE
    // --- Profiling: GPU timestamp queries ---
    // Allocate a timestamp query buffer (2 timestamps: start/end)
    webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
    if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
        ts_bufs.host_buf.Unmap();
    }

    wgpu::PassTimestampWrites ts_writes = { .querySet                  = ts_bufs.query_set,
                                            .beginningOfPassWriteIndex = 0,
                                            .endOfPassWriteIndex       = 1 };
    wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
    wgpu::ComputePassEncoder    pass      = encoder.BeginComputePass(&pass_desc);
#else
    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
#endif
    pass.SetPipeline(pipeline.pipeline);
    pass.SetBindGroup(0, bind_group);
    pass.DispatchWorkgroups(wg_x, wg_y, 1);
    pass.End();

#ifdef GGML_WEBGPU_GPU_PROFILE
    // Resolve the query set into the device buffer
    encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
    encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
#endif

    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
    if (set_rows_error_bufs) {
        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
                                   set_rows_error_bufs->host_buf.GetSize());
    }

    wgpu::CommandBuffer commands = encoder.Finish();
    webgpu_command      result   = {};
    result.commands              = commands;
    result.params_bufs           = params_bufs;
    result.set_rows_error_bufs   = set_rows_error_bufs;
#ifdef GGML_WEBGPU_GPU_PROFILE
    result.timestamp_query_bufs = ts_bufs;
    result.pipeline_name        = pipeline.name;
#endif
    return result;
}
|
|
690
|
+
|
|
691
|
+
// Fills `size` bytes of `buf`, starting at byte `offset`, with the 32-bit pattern `value` by running
// the memset compute pipeline, then blocks until that submission has completed.
static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
                                              wgpu::Buffer &   buf,
                                              uint32_t         value,
                                              size_t           offset,
                                              size_t           size) {
    // Bytes covered by one workgroup: WEBGPU_MAX_WG_SIZE threads * memset_bytes_per_thread each.
    const size_t   wg_bytes = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
    const uint32_t wg_count = CEIL_DIV(size + 3, wg_bytes);

    // Shader params: byte offset, byte count, fill pattern.
    std::vector<uint32_t> memset_params = { (uint32_t) offset, (uint32_t) size, value };

    // The whole target buffer is bound; the shader applies the offset itself.
    std::vector<wgpu::BindGroupEntry> memset_entries = {
        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
    };

    webgpu_command memset_cmd =
        ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], memset_params, memset_entries, wg_count);

    std::vector<webgpu_submission_futures> pending = { ggml_backend_webgpu_submit(ctx, { memset_cmd }) };
    ggml_backend_webgpu_wait(ctx, pending);
}
|
|
707
|
+
|
|
708
|
+
/** End WebGPU Actions */
|
|
709
|
+
|
|
710
|
+
/** GGML Backend Interface */
|
|
711
|
+
|
|
712
|
+
// Returns the name stored in this backend's context; the pointer stays valid while the context lives.
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
    const auto * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
    return backend_ctx->name.c_str();
}
|
|
716
|
+
|
|
717
|
+
// TODO: implement proper cleanup
|
|
718
|
+
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
|
719
|
+
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
|
|
720
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
|
|
721
|
+
|
|
722
|
+
#ifdef GGML_WEBGPU_CPU_PROFILE
|
|
723
|
+
std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
|
|
724
|
+
double total_cpu = 0.0;
|
|
725
|
+
for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
|
|
726
|
+
total_cpu += kv.second;
|
|
727
|
+
}
|
|
728
|
+
std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
|
|
729
|
+
std::cout << "ggml_webgpu: cpu breakdown:\n";
|
|
730
|
+
for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
|
|
731
|
+
double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
|
|
732
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
|
733
|
+
}
|
|
734
|
+
if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
|
|
735
|
+
std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
|
|
736
|
+
}
|
|
737
|
+
for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
|
|
738
|
+
double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
|
|
739
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
|
740
|
+
}
|
|
741
|
+
#endif
|
|
742
|
+
|
|
743
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
744
|
+
std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
|
|
745
|
+
double total_gpu = 0.0;
|
|
746
|
+
for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
|
|
747
|
+
total_gpu += kv.second;
|
|
748
|
+
}
|
|
749
|
+
std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
|
|
750
|
+
std::cout << "\nggml_webgpu: gpu breakdown:\n";
|
|
751
|
+
for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
|
|
752
|
+
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
|
|
753
|
+
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
|
754
|
+
}
|
|
755
|
+
#endif
|
|
756
|
+
|
|
757
|
+
#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
|
|
758
|
+
std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
|
|
759
|
+
#endif
|
|
760
|
+
|
|
761
|
+
#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
|
|
762
|
+
GGML_UNUSED(ctx);
|
|
763
|
+
#endif
|
|
764
|
+
}
|
|
765
|
+
|
|
766
|
+
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
|
|
767
|
+
return webgpu_tensor_offset(tensor) + tensor->view_offs;
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
// Returns (a handle copy of) the wgpu::Buffer held by the backend buffer context that owns `tensor`.
static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
    const auto * buf_ctx = static_cast<ggml_backend_webgpu_buffer_context *>(tensor->buffer->context);
    return buf_ctx->buffer;
}
|
|
774
|
+
|
|
775
|
+
// Bytes by which `t`'s buffer offset lies past the previous multiple of the device's
// minStorageBufferOffsetAlignment (a power of two, which the bit mask relies on).
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
    const size_t low_bits_mask = ctx->limits.minStorageBufferOffsetAlignment - 1;
    return ggml_webgpu_tensor_offset(t) & low_bits_mask;
}
|
|
779
|
+
|
|
780
|
+
// Rounds `t`'s buffer offset down to a multiple of the device's minStorageBufferOffsetAlignment
// (a power of two), giving the offset used when binding the tensor. The remainder is reported by
// ggml_webgpu_tensor_misalignment().
//
// The alignment is widened to size_t *before* complementing. Complementing the 32-bit limit and
// then widening would leave the upper 32 bits of the mask zero, silently truncating offsets
// >= 4 GiB on 64-bit hosts.
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
    size_t offset    = ggml_webgpu_tensor_offset(t);
    size_t alignment = (size_t) ctx->limits.minStorageBufferOffsetAlignment;
    return offset & ~(alignment - 1);
}
|
|
784
|
+
|
|
785
|
+
// Size of the storage-buffer binding for `t`. The binding starts at the aligned-down offset, so it
// must also cover the leading misalignment. The total is rounded up to WEBGPU_STORAGE_BUF_BINDING_MULT.
static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
    const size_t covered_bytes = ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t);
    return ROUNDUP_POW2(covered_bytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
}
|
|
788
|
+
|
|
789
|
+
// Used to determine if two tensors are the same for in-place operations
|
|
790
|
+
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
|
791
|
+
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
|
792
|
+
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
// Builds a CPY dispatch from `src` into `dst`, using the pipeline selected by (src type, dst type).
// Offsets and strides are passed to the shader in elements (byte value / type size). The dispatch
// uses ceil(nelements(dst) / WEBGPU_MAX_WG_SIZE) workgroups.
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const size_t   src_ts     = ggml_type_size(src->type);
    const size_t   dst_ts     = ggml_type_size(dst->type);
    const uint32_t n_elements = (uint32_t) ggml_nelements(dst);

    std::vector<uint32_t> params = {
        n_elements,
        // Offsets within the aligned bindings, in elements
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts),
        // src strides for dims 0..3, in elements
        (uint32_t) (src->nb[0] / src_ts),
        (uint32_t) (src->nb[1] / src_ts),
        (uint32_t) (src->nb[2] / src_ts),
        (uint32_t) (src->nb[3] / src_ts),
        // dst strides for dims 0..3, in elements
        (uint32_t) (dst->nb[0] / dst_ts),
        (uint32_t) (dst->nb[1] / dst_ts),
        (uint32_t) (dst->nb[2] / dst_ts),
        (uint32_t) (dst->nb[3] / dst_ts),
        // Logical shapes: dims 0..2 of src, then dims 0..2 of dst
        (uint32_t) src->ne[0],
        (uint32_t) src->ne[1],
        (uint32_t) src->ne[2],
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
    };

    // Binds a tensor at its aligned offset; the shader adds the element misalignment passed above.
    auto bind_tensor = [&ctx](uint32_t binding, ggml_tensor * t) {
        return wgpu::BindGroupEntry{ .binding = binding,
                                     .buffer  = ggml_webgpu_tensor_buf(t),
                                     .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                                     .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };
    std::vector<wgpu::BindGroupEntry> entries = { bind_tensor(0, src), bind_tensor(1, dst) };

    const uint32_t wg_count = CEIL_DIV(n_elements, WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_count);
}
|
|
825
|
+
|
|
826
|
+
// Builds a SET_ROWS dispatch that writes rows of `src` into `dst` at the positions given by `idx`.
// Returns std::nullopt (no dispatch) when `src` or `idx` is empty.
//
// A pooled error buffer is bound at binding 3. The command copies it back to host memory after the
// dispatch; the submit path aborts if the first word read back is non-zero.
// The vectorized pipeline variant is used when src->ne[0] is a multiple of 4 (one invocation per
// 4 elements); otherwise one invocation per element.
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                                          ggml_tensor *    src,
                                                          ggml_tensor *    idx,
                                                          ggml_tensor *    dst) {
    // For set rows specifically, we need to check if src and idx are empty tensors.
    const bool nothing_to_do = ggml_is_empty(src) || ggml_is_empty(idx);
    if (nothing_to_do) {
        return std::nullopt;
    }

    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
    // A recycled host buffer may still be mapped from its previous read-back.
    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
        error_bufs.host_buf.Unmap();
    }

    const size_t src_ts = ggml_type_size(src->type);
    const size_t idx_ts = ggml_type_size(idx->type);
    const size_t dst_ts = ggml_type_size(dst->type);

    std::vector<uint32_t> params = {
        // Offsets within the aligned bindings, in elements
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / idx_ts),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts),
        // Element strides: src dims 1..3, idx dims 0..2, dst dims 1..3
        (uint32_t) (src->nb[1] / src_ts),
        (uint32_t) (src->nb[2] / src_ts),
        (uint32_t) (src->nb[3] / src_ts),
        (uint32_t) (idx->nb[0] / idx_ts),
        (uint32_t) (idx->nb[1] / idx_ts),
        (uint32_t) (idx->nb[2] / idx_ts),
        (uint32_t) (dst->nb[1] / dst_ts),
        (uint32_t) (dst->nb[2] / dst_ts),
        (uint32_t) (dst->nb[3] / dst_ts),
        // Shape of src
        (uint32_t) src->ne[0],
        (uint32_t) src->ne[1],
        (uint32_t) src->ne[2],
        (uint32_t) src->ne[3],
        // Shape of idx (dims 1..2)
        (uint32_t) idx->ne[1],
        (uint32_t) idx->ne[2],
    };

    auto bind_tensor = [&ctx](uint32_t binding, ggml_tensor * t) {
        return wgpu::BindGroupEntry{ .binding = binding,
                                     .buffer  = ggml_webgpu_tensor_buf(t),
                                     .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                                     .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };
    std::vector<wgpu::BindGroupEntry> entries = {
        bind_tensor(0, src),
        bind_tensor(1, idx),
        bind_tensor(2, dst),
        wgpu::BindGroupEntry{
            .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() },
    };

    const int       vectorized = src->ne[0] % 4 == 0;
    webgpu_pipeline pipeline   = ctx->set_rows_pipelines[0][vectorized];
    const uint32_t  threads    = vectorized ? (uint32_t) ((src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4)) :
                                              (uint32_t) (src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]);
    const uint32_t  wg_count   = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);

    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_count, 1, error_bufs);
}
|
|
885
|
+
|
|
886
|
+
// Builds a GET_ROWS dispatch reading rows of `src` selected by `idx` into `dst`.
// The vectorized pipeline variant is chosen only for F32 sources when dst->ne[0] is a multiple of 4.
// Offsets/strides are passed in elements. The dispatch uses
// ceil(dst->ne[1] * ne[2] * ne[3] / WEBGPU_MAX_WG_SIZE) workgroups.
static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                           ggml_tensor *    src,
                                           ggml_tensor *    idx,
                                           ggml_tensor *    dst) {
    const size_t src_ts = ggml_type_size(src->type);
    const size_t idx_ts = ggml_type_size(idx->type);
    const size_t dst_ts = ggml_type_size(dst->type);

    std::vector<uint32_t> params = {
        // Offsets within the aligned bindings, in elements
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / idx_ts),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts),
        // Element strides: src dims 1..3, idx dims 0..2, dst dims 1..3
        (uint32_t) (src->nb[1] / src_ts),
        (uint32_t) (src->nb[2] / src_ts),
        (uint32_t) (src->nb[3] / src_ts),
        (uint32_t) (idx->nb[0] / idx_ts),
        (uint32_t) (idx->nb[1] / idx_ts),
        (uint32_t) (idx->nb[2] / idx_ts),
        (uint32_t) (dst->nb[1] / dst_ts),
        (uint32_t) (dst->nb[2] / dst_ts),
        (uint32_t) (dst->nb[3] / dst_ts),
        // Shape of dst
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        (uint32_t) dst->ne[3],
        // Shape of idx (dims 1..2)
        (uint32_t) idx->ne[1],
        (uint32_t) idx->ne[2],
    };

    auto bind_tensor = [&ctx](uint32_t binding, ggml_tensor * t) {
        return wgpu::BindGroupEntry{ .binding = binding,
                                     .buffer  = ggml_webgpu_tensor_buf(t),
                                     .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                                     .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };
    std::vector<wgpu::BindGroupEntry> entries = { bind_tensor(0, src), bind_tensor(1, idx), bind_tensor(2, dst) };

    const uint32_t  wg_count   = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
    const uint32_t  vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
    webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_count);
}
|
|
927
|
+
|
|
928
|
+
// Splits `total_wg` workgroups over a 2D dispatch so that neither dimension exceeds `max_per_dim`.
// wg_y is the fewest rows that fit total_wg at max_per_dim per row. wg_x is then the smallest width
// with wg_x * wg_y >= total_wg, which is always <= max_per_dim. For total_wg <= max_per_dim the
// result is exactly (total_wg, 1).
static void ggml_webgpu_mul_mat_split_wg_2d(uint32_t   total_wg,
                                            uint32_t   max_per_dim,
                                            uint32_t & wg_x,
                                            uint32_t & wg_y) {
    wg_y = (total_wg > max_per_dim) ? (total_wg / max_per_dim + (total_wg % max_per_dim != 0 ? 1u : 0u)) : 1u;
    wg_x = total_wg / wg_y + (total_wg % wg_y != 0 ? 1u : 0u);
}

// Builds a MUL_MAT dispatch computing dst from src0 and src1.
//
// Pipeline choice:
// - default: mul_mat_pipelines[src0][src1][0], one invocation per dst element
//   (ceil(nelements / WEBGPU_MUL_MAT_WG_SIZE) workgroups);
// - "fast" paths for (F16, F16) and (F32/F16/Q4_0 src0, F32 src1):
//   * dst->ne[1] == 1: mul_mat_vec pipeline, vectorized only for F32/F16 src0, with workgroups
//     split over x/y to respect maxComputeWorkgroupsPerDimension;
//   * otherwise: tiled mul_mat (subgroup-matrix tiles when supported and not on Emscripten).
// Offsets/strides are passed in elements (or blocks for quantized types).
static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
                                          ggml_tensor *    src0,
                                          ggml_tensor *    src1,
                                          ggml_tensor *    dst) {
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 2
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 2
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 3
        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 3
        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(src1),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
    };

    webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];

    uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
    uint32_t wg_y = 1;

    bool use_fast = false;
    switch (src1->type) {
        case GGML_TYPE_F16:
            use_fast = (src0->type == GGML_TYPE_F16);
            break;
        case GGML_TYPE_F32:
            switch (src0->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                    use_fast = true;
                    break;
                default:
                    break;
            }
            break;
        default:
            break;
    }

    if (use_fast) {
        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
        if (dst->ne[1] == 1) {
            // We don't support vectorized mul_mat_vec for quantized types
            vectorized = vectorized && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
            pipeline   = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
            uint32_t batches       = dst->ne[2] * dst->ne[3];
            uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
            uint32_t total_wg      = output_groups * batches;
            // Previously wg_x = total_wg % max, which dispatched zero workgroups for exact multiples
            // of the limit and too few whenever total_wg exceeded it.
            // NOTE(review): when total_wg > max, wg_x * wg_y may exceed total_wg; the mul_mat_vec
            // shader is expected to discard workgroups past the real count - confirm.
            ggml_webgpu_mul_mat_split_wg_2d(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
        } else {
            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
            uint32_t wg_m;
            uint32_t wg_n;
#ifndef __EMSCRIPTEN__
            if (ctx->supports_subgroup_matrix) {
                // The total number of subgroups/workgroups needed per matrix.
                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
                wg_n                  = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
            } else
#endif
            {
                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
                wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
                wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
            }

            // NOTE(review): this 1D count is not checked against maxComputeWorkgroupsPerDimension.
            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
        }
    }
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
|
|
1028
|
+
|
|
1029
|
+
// Builds a flash-attention dispatch over Q/K/V with an optional `mask` and optional `sinks`, writing `dst`.
//
// Scalar op params are read from dst->op_params: [0] scale, [1] max_bias, [2] logit_softcap. When
// logit_softcap is non-zero, the scale is pre-divided by it. m0/m1 are slope bases (presumably the
// ALiBi bias slopes) derived from max_bias and the largest power of two <= the head count Q->ne[2].
//
// Bindings: 0 = Q, 1 = K, 2 = V, then mask (if any), then sinks (if any), then dst;
// ggml_backend_webgpu_build appends the params buffer after these.
//
// The shader is specialized per flash_attn_pipeline_key and cached in ctx->flash_attn_pipelines,
// together with the shader's tiling decisions. The cache is looked up and populated only while
// holding ctx->mutex. The dispatch uses ceil(Q->ne[1] / q_tile) workgroups per head, for every
// head (Q->ne[2]) and batch (Q->ne[3]).
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                             ggml_tensor *    Q,
                                             ggml_tensor *    K,
                                             ggml_tensor *    V,
                                             ggml_tensor *    mask,
                                             ggml_tensor *    sinks,
                                             ggml_tensor *    dst) {
    // Copies a float's bit pattern into a uint32_t for the params buffer. memcpy is the well-defined
    // way to type-pun; dereferencing a reinterpreted pointer violates strict aliasing.
    auto f32_bits = [](float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        return bits;
    };

    float scale;
    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
    float max_bias;
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    float logit_softcap;
    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }
    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const int has_mask  = (mask != nullptr);
    const int has_sinks = (sinks != nullptr);

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
        has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) Q->ne[2],                                                // number of heads
        (uint32_t) Q->ne[1],                                                // sequence length (Q)
        (uint32_t) K->ne[1],                                                // sequence length (K/V)
        (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)),                    // stride (elements/blocks) of Q in dimension 1
        (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)),                    // stride (elements/blocks) of Q in dimension 2
        (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)),                    // stride (elements/blocks) of Q in dimension 3
        (uint32_t) (K->nb[1] / ggml_type_size(K->type)),                    // stride (elements/blocks) of K in dimension 1
        (uint32_t) (K->nb[2] / ggml_type_size(K->type)),                    // stride (elements/blocks) of K in dimension 2
        (uint32_t) (K->nb[3] / ggml_type_size(K->type)),                    // stride (elements/blocks) of K in dimension 3
        (uint32_t) (V->nb[1] / ggml_type_size(V->type)),                    // stride (elements/blocks) of V in dimension 1
        (uint32_t) (V->nb[2] / ggml_type_size(V->type)),                    // stride (elements/blocks) of V in dimension 2
        (uint32_t) (V->nb[3] / ggml_type_size(V->type)),                    // stride (elements/blocks) of V in dimension 3
        has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0,  // stride of mask dim 3
        (uint32_t) (Q->ne[2] / K->ne[2]),  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
        f32_bits(scale),                   // scale (possibly adjusted for logit softcap)
        f32_bits(max_bias),
        f32_bits(logit_softcap),
        f32_bits(n_head_log2),
        f32_bits(m0),
        f32_bits(m1),
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(Q),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
          .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(K),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
          .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
        { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(V),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
          .size    = ggml_webgpu_tensor_binding_size(ctx, V) }
    };
    // Optional inputs take the next free binding slots, so slot numbers shift when they are absent.
    uint32_t binding_index = 3;
    if (has_mask) {
        entries.push_back({ .binding = binding_index++,
                            .buffer  = ggml_webgpu_tensor_buf(mask),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
    }
    if (has_sinks) {
        entries.push_back({ .binding = binding_index++,
                            .buffer  = ggml_webgpu_tensor_buf(sinks),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
    }
    entries.push_back({ .binding = binding_index++,
                        .buffer  = ggml_webgpu_tensor_buf(dst),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });

    bool kv_direct =
        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);

    flash_attn_pipeline_key key = {
        .q_type             = Q->type,
        .kv_type            = K->type,
        .dst_type           = dst->type,
        .head_dim_qk        = (uint32_t) Q->ne[0],
        .head_dim_v         = (uint32_t) V->ne[0],
        .kv_direct          = kv_direct,
        .has_mask           = static_cast<bool>(has_mask),
        .has_sinks          = static_cast<bool>(has_sinks),
        .uses_logit_softcap = logit_softcap != 0.0f,
    };

    webgpu_pipeline                         pipeline;
    ggml_webgpu_flash_attn_shader_decisions decisions = {};
    {
        // The cache container is not safe to read while another thread inserts into it, so the
        // lookup is done under the same lock as the insertion (no unlocked fast path).
        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
        auto it = ctx->flash_attn_pipelines.find(key);
        if (it != ctx->flash_attn_pipelines.end()) {
            pipeline  = it->second;
            decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
        } else {
            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
                .kv_type            = K->type,
                .head_dim_qk        = (uint32_t) Q->ne[0],
                .head_dim_v         = (uint32_t) V->ne[0],
                .kv_direct          = kv_direct,
                .has_mask           = static_cast<bool>(has_mask),
                .has_sinks          = static_cast<bool>(has_sinks),
                .uses_logit_softcap = logit_softcap != 0.0f,
                .sg_mat_m           = ctx->sg_mat_m,
                .sg_mat_n           = ctx->sg_mat_n,
                .sg_mat_k           = ctx->sg_mat_k,
                .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize,
                .max_subgroup_size  = ctx->max_subgroup_size
            };

            ggml_webgpu_processed_shader processed =
                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
            // The cached pipeline keeps its own heap copy of the shader decisions.
            pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
            ctx->flash_attn_pipelines.emplace(key, pipeline);
            decisions = processed.decisions;
        }
    }

    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
|
|
1168
|
+
|
|
1169
|
+
// Builds a unary-op dispatch for `dst = op(src)`. The pipeline is selected by
// (unary op, dst type, in-place). When src and dst alias (same buffer and offset), only binding 0 is
// bound and the in-place variant is used. Offsets/strides are passed in elements. XIELU appends its
// four float op params (op_params[1..4]) as raw bit patterns. The dispatch uses
// ceil(nelements(dst) / WEBGPU_MAX_WG_SIZE) workgroups.
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    uint32_t      ne       = (uint32_t) ggml_nelements(dst);
    ggml_unary_op unary_op = ggml_get_unary_op(dst);
    uint32_t      inplace  = ggml_webgpu_tensor_equal(src, dst);

    std::vector<uint32_t> params = {
        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        // Convert byte-strides to element-strides
        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        // Logical shapes
        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
    };

    switch (unary_op) {
        case GGML_UNARY_OP_XIELU:
            {
                // Pass the float parameters through the params buffer as their bit patterns.
                // memcpy is used instead of dereferencing a reinterpret_cast pointer, which
                // violates strict aliasing.
                const float xielu_params[] = {
                    ggml_get_op_params_f32(dst, 1),  // alpha_n
                    ggml_get_op_params_f32(dst, 2),  // alpha_p
                    ggml_get_op_params_f32(dst, 3),  // beta
                    ggml_get_op_params_f32(dst, 4),  // eps
                };
                for (float v : xielu_params) {
                    uint32_t bits;
                    memcpy(&bits, &v, sizeof(bits));
                    params.push_back(bits);
                }
                break;
            }
        default:
            break;
    }

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
    };
    // In-place variants read and write binding 0, so dst is bound separately only when it differs.
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
}
|
|
1222
|
+
|
|
1223
|
+
// Encodes an element-wise binary op (dst = src0 OP src1) as one dispatch of
// CEIL_DIV(nelements(dst), WEBGPU_MAX_WG_SIZE) workgroups using the caller-selected pipeline.
//
// params layout (order must match the shader): nelements(dst); src0/src1/dst misalignment in
// elements; src1 strides nb[0..3] in elements; src0 ne[0..2]; src1 ne[0..3].
// Bindings: 0 = src0, 1 = src1, 2 = dst. Binding 2 is omitted when `inplace` is true.
static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
                                            ggml_tensor *     src0,
                                            ggml_tensor *     src1,
                                            ggml_tensor *     dst,
                                            webgpu_pipeline & pipeline,
                                            bool              inplace) {
    const size_t ts0 = ggml_type_size(src0->type);
    const size_t ts1 = ggml_type_size(src1->type);
    const size_t tsd = ggml_type_size(dst->type);

    std::vector<uint32_t> params;
    params.reserve(15);
    params.push_back((uint32_t) ggml_nelements(dst));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ts0));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ts1));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / tsd));
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) (src1->nb[d] / ts1));
    }
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) src0->ne[d]);
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) src1->ne[d]);
    }

    // Builds the bind-group entry that exposes tensor `t` at slot `slot`.
    auto bind = [&ctx](uint32_t slot, ggml_tensor * t) {
        return wgpu::BindGroupEntry{ .binding = slot,
                                     .buffer  = ggml_webgpu_tensor_buf(t),
                                     .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                                     .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back(bind(0, src0));
    entries.push_back(bind(1, src1));
    if (!inplace) {
        entries.push_back(bind(2, dst));
    }

    const uint32_t groups = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, groups);
}
|
|
1267
|
+
|
|
1268
|
+
// Encodes GGML_OP_RMS_NORM with one workgroup per row of src (ggml_nrows(src) workgroups).
//
// params layout (order must match the shader): src/dst misalignment in elements; src strides
// nb[1..3] and dst strides nb[1..3] in elements; src ne[0..3]; then the raw bits of
// op_params[0] (epsilon, which the shader reads as f32).
// When src and dst are the same tensor, dst gets no separate binding and the inplace
// pipeline (rms_norm_pipelines[1]) is used.
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const int    inplace = ggml_webgpu_tensor_equal(src, dst);
    const size_t src_ts  = ggml_type_size(src->type);
    const size_t dst_ts  = ggml_type_size(dst->type);

    std::vector<uint32_t> params;
    params.reserve(13);
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts));
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (src->nb[d] / src_ts));
    }
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (dst->nb[d] / dst_ts));
    }
    for (int d = 0; d < 4; d++) {
        params.push_back((uint32_t) src->ne[d]);
    }
    uint32_t eps_bits;
    memcpy(&eps_bits, dst->op_params, sizeof(eps_bits));
    params.push_back(eps_bits);  // epsilon, treated as f32 in the shader

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back({ .binding = 0,
                        .buffer  = ggml_webgpu_tensor_buf(src),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src) });
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
}
|
|
1302
|
+
|
|
1303
|
+
// Encodes GGML_OP_ROPE as one dispatch of CEIL_DIV(nelements(dst), WEBGPU_MAX_WG_SIZE) workgroups.
//
// src0 is the input, src1 is bound at slot 1 (presumably the positions tensor -- confirm with the
// shader), and src2 is an optional frequency-factor tensor. Scalar settings come from
// dst->op_params: [1] n_dims, [2] mode, [4] n_ctx_orig, [5..10] freq_base, freq_scale,
// ext_factor, attn_factor, beta_fast, beta_slow (f32), [11..14] sections.
// Float values are passed to the shader as raw bit patterns; params order must match the shader.
// Bindings: 0 = src0, 1 = src1, 2 = src2 (if present), then dst at the next slot unless inplace.
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
                                       ggml_tensor *    src0,
                                       ggml_tensor *    src1,
                                       ggml_tensor *    src2,
                                       ggml_tensor *    dst) {
    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
    const int has_freq_factor = (src2 != nullptr);

    const int32_t * op_params  = (const int32_t *) dst->op_params;
    const int       n_dims     = op_params[1];
    const int       mode       = op_params[2];
    const int       n_ctx_orig = op_params[4];

    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    memcpy(&freq_base, op_params + 5, sizeof(float));
    memcpy(&freq_scale, op_params + 6, sizeof(float));
    memcpy(&ext_factor, op_params + 7, sizeof(float));
    memcpy(&attn_factor, op_params + 8, sizeof(float));
    memcpy(&beta_fast, op_params + 9, sizeof(float));
    memcpy(&beta_slow, op_params + 10, sizeof(float));

    int sections[4];
    memcpy(sections, op_params + 11, 4 * sizeof(int));

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

    // Bit-cast a float to uint32_t for the params buffer. memcpy is the well-defined way to do
    // this; `*(uint32_t *) &f` violates strict aliasing and is undefined behavior.
    auto f32_bits = [](float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        return bits;
    };

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(src0) / 2,
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        (uint32_t) n_dims,
        (uint32_t) mode,
        f32_bits(theta_scale),
        f32_bits(attn_factor),
        f32_bits(freq_scale),
        f32_bits(ext_factor),
        f32_bits(corr_dims[0]),
        f32_bits(corr_dims[1]),
        (uint32_t) sections[0],
        (uint32_t) sections[1],
        (uint32_t) sections[2],
        (uint32_t) sections[3]
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
         .buffer  = ggml_webgpu_tensor_buf(src1),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
    uint32_t dst_binding = 2;
    if (has_freq_factor) {
        dst_binding = 3;
        entries.push_back({ .binding = 2,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
    }
    if (!inplace) {
        entries.push_back({ .binding = dst_binding,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
|
|
1389
|
+
|
|
1390
|
+
// Encodes GGML_OP_GLU as one dispatch of CEIL_DIV(nelements(dst), WEBGPU_MAX_WG_SIZE) workgroups.
//
// When src1 is non-null the op is "split": src1 supplies the second operand and is bound at
// slot 1. Otherwise src0 is used for both operand stride sets. dst is always bound, at slot 1
// or slot 2 depending on whether src1 is present.
// params layout (order must match the shader): misalignments (src0, src1 or 0, dst), src0
// strides nb[1..3], second-operand strides nb[1..3], dst strides nb[1..3] (all in elements),
// nelements(dst), dst ne[0..2], then op_params[1] (swapped), [2] (alpha) and [3] (limit),
// the last two used by swiglu_oai.
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
    const int split = (src1 != nullptr);

    // The tensor whose strides describe the second operand: src1 when split, else src0 itself.
    const ggml_tensor * gate    = split ? src1 : src0;
    const size_t        ts0     = ggml_type_size(src0->type);
    const size_t        ts_gate = ggml_type_size(gate->type);
    const size_t        ts_dst  = ggml_type_size(dst->type);
    const int32_t *     op      = (const int32_t *) dst->op_params;

    std::vector<uint32_t> params;
    params.reserve(19);
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ts0));
    params.push_back(split ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0);
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ts_dst));
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (src0->nb[d] / ts0));
    }
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (gate->nb[d] / ts_gate));
    }
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (dst->nb[d] / ts_dst));
    }
    params.push_back((uint32_t) ggml_nelements(dst));
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) dst->ne[d]);
    }
    params.push_back((uint32_t) op[1]);  // swapped
    params.push_back((uint32_t) op[2]);  // alpha, for swiglu_oai
    params.push_back((uint32_t) op[3]);  // limit, for swiglu_oai

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back({ .binding = 0,
                        .buffer  = ggml_webgpu_tensor_buf(src0),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src0) });
    if (split) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
    }
    const uint32_t dst_slot = split ? 2u : 1u;
    entries.push_back({ .binding = dst_slot,
                        .buffer  = ggml_webgpu_tensor_buf(dst),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });

    webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
    const uint32_t  groups   = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, groups);
}
|
|
1441
|
+
|
|
1442
|
+
// Encodes GGML_OP_SCALE as one dispatch of CEIL_DIV(nelements(dst), WEBGPU_MAX_WG_SIZE) workgroups.
//
// params layout (order must match the shader): src/dst misalignment in elements; src strides
// nb[1..3] and dst strides nb[1..3] in elements; nelements(dst); src ne[0..2]; then the raw bits
// of op_params[0] (scale) and op_params[1] (bias).
// When src and dst are the same tensor, dst gets no separate binding and scale_pipelines[1] is used.
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    const int       inplace = ggml_webgpu_tensor_equal(src, dst);
    const size_t    src_ts  = ggml_type_size(src->type);
    const size_t    dst_ts  = ggml_type_size(dst->type);
    const int32_t * op      = (const int32_t *) dst->op_params;

    std::vector<uint32_t> params;
    params.reserve(14);
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / src_ts));
    params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / dst_ts));
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (src->nb[d] / src_ts));
    }
    for (int d = 1; d < 4; d++) {
        params.push_back((uint32_t) (dst->nb[d] / dst_ts));
    }
    params.push_back((uint32_t) ggml_nelements(dst));
    for (int d = 0; d < 3; d++) {
        params.push_back((uint32_t) src->ne[d]);
    }
    params.push_back((uint32_t) op[0]);  // scale
    params.push_back((uint32_t) op[1]);  // bias

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back({ .binding = 0,
                        .buffer  = ggml_webgpu_tensor_buf(src),
                        .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
                        .size    = ggml_webgpu_tensor_binding_size(ctx, src) });
    if (!inplace) {
        entries.push_back({ .binding = 1,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    const uint32_t groups = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
    return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, groups);
}
|
|
1478
|
+
|
|
1479
|
+
// Encodes GGML_OP_SOFT_MAX with one workgroup per row of dst (ggml_nrows(dst) workgroups).
//
// src1 is an optional mask and src2 an optional sink tensor. mask_type is src1->type when a mask
// is present, and 2 when it is absent; it indexes soft_max_pipelines together with has_sink and
// inplace. Bindings are assigned in order: src0, then mask (if any), then sink (if any), then dst
// (unless inplace). The ALiBi slopes m0/m1 are derived from max_bias (op_params[1]) and the head
// count src0->ne[2]. Float values are passed as raw bit patterns; params order must match the shader.
static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                                           ggml_tensor *    src0,
                                           ggml_tensor *    src1,
                                           ggml_tensor *    src2,
                                           ggml_tensor *    dst) {
    const int  inplace   = ggml_webgpu_tensor_equal(src0, dst);
    const int  mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
    const bool has_mask  = mask_type < 2;
    const int  has_sink  = (src2 != nullptr);
    float      max_bias;
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // Bit-cast a float to uint32_t for the params buffer. memcpy is the well-defined way to do
    // this; `*(uint32_t *) &f` violates strict aliasing and is undefined behavior.
    auto f32_bits = [](float f) {
        uint32_t bits;
        memcpy(&bits, &f, sizeof(bits));
        return bits;
    };

    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
        has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
        has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
        has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
        (uint32_t) ggml_nelements(dst),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        (uint32_t) src0->ne[2],
        has_mask ? (uint32_t) src1->ne[2] : 0,
        has_mask ? (uint32_t) src1->ne[3] : 0,
        *(uint32_t *) dst->op_params,  // scale (op_params is int32 storage, so this read is well-defined)
        f32_bits(max_bias),
        f32_bits(n_head_log2),
        f32_bits(m0),
        f32_bits(m1)
    };

    std::vector<wgpu::BindGroupEntry> entries = {
        { .binding = 0,
         .buffer  = ggml_webgpu_tensor_buf(src0),
         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
    };
    uint32_t binding_num = 1;
    if (has_mask) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src1),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
        binding_num++;
    }
    if (has_sink) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(src2),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
        binding_num++;
    }
    if (!inplace) {
        entries.push_back({ .binding = binding_num,
                            .buffer  = ggml_webgpu_tensor_buf(dst),
                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
                                     ggml_nrows(dst));
}
|
|
1551
|
+
|
|
1552
|
+
// Returns the encoded command, or std::nullopt if the operation is a no-op
|
|
1553
|
+
static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|
1554
|
+
if (ggml_is_empty(node)) {
|
|
1555
|
+
return std::nullopt;
|
|
1556
|
+
}
|
|
1557
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
|
|
1558
|
+
|
|
1559
|
+
ggml_tensor * src0 = node->src[0];
|
|
1560
|
+
ggml_tensor * src1 = node->src[1];
|
|
1561
|
+
ggml_tensor * src2 = node->src[2];
|
|
1562
|
+
|
|
1563
|
+
switch (node->op) {
|
|
1564
|
+
// no-ops
|
|
1565
|
+
case GGML_OP_NONE:
|
|
1566
|
+
case GGML_OP_VIEW:
|
|
1567
|
+
case GGML_OP_PERMUTE:
|
|
1568
|
+
case GGML_OP_TRANSPOSE:
|
|
1569
|
+
case GGML_OP_RESHAPE:
|
|
1570
|
+
return std::nullopt;
|
|
1571
|
+
case GGML_OP_CPY:
|
|
1572
|
+
case GGML_OP_CONT:
|
|
1573
|
+
return ggml_webgpu_cpy(ctx, src0, node);
|
|
1574
|
+
case GGML_OP_SET_ROWS:
|
|
1575
|
+
return ggml_webgpu_set_rows(ctx, src0, src1, node);
|
|
1576
|
+
case GGML_OP_GET_ROWS:
|
|
1577
|
+
return ggml_webgpu_get_rows(ctx, src0, src1, node);
|
|
1578
|
+
case GGML_OP_MUL_MAT:
|
|
1579
|
+
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
|
1580
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
1581
|
+
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
|
|
1582
|
+
case GGML_OP_ADD:
|
|
1583
|
+
{
|
|
1584
|
+
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
|
1585
|
+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
|
|
1586
|
+
}
|
|
1587
|
+
case GGML_OP_SUB:
|
|
1588
|
+
{
|
|
1589
|
+
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
|
1590
|
+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
|
|
1591
|
+
}
|
|
1592
|
+
case GGML_OP_MUL:
|
|
1593
|
+
{
|
|
1594
|
+
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
|
1595
|
+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
|
|
1596
|
+
}
|
|
1597
|
+
case GGML_OP_DIV:
|
|
1598
|
+
{
|
|
1599
|
+
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
|
1600
|
+
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
|
|
1601
|
+
}
|
|
1602
|
+
case GGML_OP_RMS_NORM:
|
|
1603
|
+
return ggml_webgpu_rms_norm(ctx, src0, node);
|
|
1604
|
+
case GGML_OP_ROPE:
|
|
1605
|
+
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
|
|
1606
|
+
case GGML_OP_GLU:
|
|
1607
|
+
return ggml_webgpu_glu(ctx, src0, src1, node);
|
|
1608
|
+
case GGML_OP_SCALE:
|
|
1609
|
+
return ggml_webgpu_scale(ctx, src0, node);
|
|
1610
|
+
case GGML_OP_SOFT_MAX:
|
|
1611
|
+
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
|
1612
|
+
case GGML_OP_UNARY:
|
|
1613
|
+
return ggml_webgpu_unary_op(ctx, src0, node);
|
|
1614
|
+
default:
|
|
1615
|
+
return std::nullopt;
|
|
1616
|
+
}
|
|
1617
|
+
}
|
|
1618
|
+
|
|
1619
|
+
// Backend graph_compute hook: encodes every node of cgraph and submits the resulting commands
// in batches, then blocks until all submissions have completed.
// The batch size is WEBGPU_NUM_PARAM_BUFS divided by the number of graph_compute calls currently
// in flight (ctx->inflight_threads), clamped to [1, WEBGPU_COMMAND_SUBMIT_BATCH_SIZE]. It is
// re-read for every node because other calls may start or finish meanwhile.
// Always returns GGML_STATUS_SUCCESS.
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");

    auto *         backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
    webgpu_context ctx         = backend_ctx->webgpu_ctx;

    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

    ctx->inflight_threads++;

    std::vector<webgpu_command>            pending;
    std::vector<webgpu_submission_futures> futures;

    for (int i = 0; i < cgraph->n_nodes; i++) {
        std::optional<webgpu_command> cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
        if (cmd.has_value()) {
            pending.push_back(*cmd);
        }

        uint32_t threads = ctx->inflight_threads;
        if (threads < 1u) {
            threads = 1u;
        }
        uint32_t batch_size = WEBGPU_NUM_PARAM_BUFS / threads;
        if (batch_size < 1u) {
            batch_size = 1u;
        }
        if (batch_size > WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
            batch_size = WEBGPU_COMMAND_SUBMIT_BATCH_SIZE;
        }

        if (pending.size() < batch_size) {
            continue;
        }
        futures.push_back(ggml_backend_webgpu_submit(ctx, pending));
        // Process events and check for completed submissions
        ctx->instance.ProcessEvents();
        ggml_backend_webgpu_wait(ctx, futures, false);
        pending.clear();
    }

    if (!pending.empty()) {
        futures.push_back(ggml_backend_webgpu_submit(ctx, pending));
    }

    ggml_backend_webgpu_wait(ctx, futures);
    ctx->inflight_threads--;
    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
    return GGML_STATUS_SUCCESS;
}
|
|
1657
|
+
|
|
1658
|
+
// Backend vtable (positional; the order must follow the ggml_backend_i declaration).
// Only get_name, free and graph_compute are implemented. The async transfer, synchronize,
// graph-plan, event and graph_optimize hooks are left null.
static ggml_backend_i ggml_backend_webgpu_i = {
    /* .get_name           = */ ggml_backend_webgpu_name,
    /* .free               = */ ggml_backend_webgpu_free,
    /* .set_tensor_async   = */ nullptr,
    /* .get_tensor_async   = */ nullptr,
    /* .cpy_tensor_async   = */ nullptr,
    /* .synchronize        = */ nullptr,
    /* .graph_plan_create  = */ nullptr,
    /* .graph_plan_free    = */ nullptr,
    /* .graph_plan_update  = */ nullptr,
    /* .graph_plan_compute = */ nullptr,
    /* .graph_compute      = */ ggml_backend_webgpu_graph_compute,
    /* .event_record       = */ nullptr,
    /* .event_wait         = */ nullptr,
    /* .graph_optimize     = */ nullptr,
};
|
|
1674
|
+
|
|
1675
|
+
/* End GGML Backend Interface */
|
|
1676
|
+
|
|
1677
|
+
/* GGML Backend Buffer Interface */
|
|
1678
|
+
|
|
1679
|
+
// Buffer free hook: destroys the underlying WebGPU buffer and releases the host-side buffer
// context. Previously the context object was never released, so it leaked on every free.
// NOTE(review): assumes buffer->context was created with `new` in
// ggml_backend_webgpu_buffer_type_alloc_buffer (the allocation is not visible here) -- confirm.
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
    ctx->buffer.Destroy();
    delete ctx;
}
|
|
1683
|
+
|
|
1684
|
+
// Returns the "fake" base pointer.
|
|
1685
|
+
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
1686
|
+
GGML_UNUSED(buffer);
|
|
1687
|
+
return webgpu_ptr_base;
|
|
1688
|
+
}
|
|
1689
|
+
|
|
1690
|
+
// Buffer memset_tensor hook: fills `size` bytes of `tensor`, starting `offset` bytes into it,
// with the byte `value`. Returns immediately when size is zero.
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
                                                     ggml_tensor *         tensor,
                                                     uint8_t               value,
                                                     size_t                offset,
                                                     size_t                size) {
    if (size == 0) {
        WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
        return;
    }

    WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);

    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;

    // value is cast for logging: streaming a uint8_t inserts it as a raw character
    // (e.g. value 0 would write a NUL byte into the log) instead of printing the number.
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", "
                                                                 << (uint32_t) value << ", " << offset << ", " << size
                                                                 << ")");

    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

    // This is a trick to set all bytes of a u32 to the same 1 byte value.
    uint32_t val32 = (uint32_t) value * 0x01010101;
    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
}
|
|
1714
|
+
|
|
1715
|
+
// Buffer set_tensor hook: uploads `size` bytes from host memory `data` into `tensor`, starting
// `offset` bytes into it.
// queue.WriteBuffer receives only the largest multiple-of-4 prefix. Any 1-3 trailing bytes are
// packed into the low-addressed bytes of a zeroed u32 and handed to
// ggml_backend_webgpu_buffer_memset with size equal to the tail length. When there is no tail,
// the function instead blocks until the queue reports its submitted work as done.
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                  ggml_tensor *         tensor,
                                                  const void *          data,
                                                  size_t                offset,
                                                  size_t                size) {
    WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
    auto *         buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;

    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                              << ", " << offset << ", " << size << ")");

    const size_t dst_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
    const size_t tail       = size % 4;
    const size_t body       = size - tail;

    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, dst_offset, data, body);

    if (tail == 0) {
        // wait for WriteBuffer to complete
        webgpu_ctx->instance.WaitAny(
            webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
                                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
                                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
                                                                         std::string(message).c_str());
                                                      }
                                                  }),
            UINT64_MAX);
    } else {
        // Pack the trailing bytes into a u32 (remaining bytes stay zero).
        uint32_t tail_word = 0;
        memcpy(&tail_word, (const uint8_t *) data + body, tail);
        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, tail_word, dst_offset + body, tail);
    }
    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
}
|
|
1758
|
+
|
|
1759
|
+
// Buffer get_tensor hook: copies `size` bytes of `tensor`, starting `offset` bytes into it, into
// host memory `data`.
// The copy length is rounded up to a multiple of 4. The data goes GPU buffer -> a shared
// staging buffer (webgpu_ctx->get_tensor_staging_buf, grown on demand) -> mapped read -> `data`.
// Only `size` bytes are written to `data`. The staging buffer is guarded by webgpu_ctx->mutex.
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                  const ggml_tensor *   tensor,
                                                  void *                data,
                                                  size_t                offset,
                                                  size_t                size) {
    WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
    auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                              << ", " << offset << ", " << size << ")");
    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
    wgpu::Device   device     = webgpu_ctx->device;

    const size_t src_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
    // Round the copy length up to the next multiple of 4 bytes.
    const size_t copy_size  = (size % 4 == 0) ? size : size + (4 - (size % 4));

    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);

    auto &     staging  = webgpu_ctx->get_tensor_staging_buf;
    const bool too_small = staging == nullptr || staging.GetSize() < copy_size;
    if (too_small) {
        // Replace a missing or undersized staging buffer.
        if (staging) {
            staging.Destroy();
        }
        ggml_webgpu_create_buffer(device, staging, copy_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
                                  "get_tensor_staging_buf");
    }

    // GPU-side copy into the staging buffer.
    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(buf_ctx->buffer, src_offset, staging, 0, copy_size);
    wgpu::CommandBuffer cmd_buf = encoder.Finish();
    webgpu_ctx->queue.Submit(1, &cmd_buf);

    // Map for reading. The explicit length matters because staging may be larger than copy_size.
    ggml_backend_webgpu_map_buffer(webgpu_ctx, staging, wgpu::MapMode::Read, 0, copy_size);
    const void * mapped = staging.GetConstMappedRange(0, copy_size);

    std::memcpy(data, mapped, size);
    staging.Unmap();
    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
}
|
|
1808
|
+
|
|
1809
|
+
// Buffer clear hook: fills the whole buffer (buffer->size bytes from offset 0) with the byte `value`.
// ggml_backend_webgpu_buffer_memset takes a 32-bit fill word; memset_tensor replicates its byte
// into all four bytes before calling it. clear previously passed the raw byte, which for any
// non-zero value would presumably fill with the pattern {value, 0, 0, 0} instead of the byte.
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
    WEBGPU_CPU_PROFILE_TOTAL_START(clear);
    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
    // Replicate the byte into every byte of the u32 fill word (same trick as memset_tensor).
    uint32_t val32 = (uint32_t) value * 0x01010101;
    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, 0, buffer->size);
    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
}
|
|
1816
|
+
|
|
1817
|
+
// Buffer vtable handed to ggml_backend_buffer_init. Designated initializers tie each callback to its
// field by name; the order below follows the member order of ggml_backend_buffer_i.
static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
    .free_buffer   = ggml_backend_webgpu_buffer_free_buffer,
    .get_base      = ggml_backend_webgpu_buffer_get_base,
    .init_tensor   = NULL,  // TODO: optional, needed?
    .memset_tensor = ggml_backend_webgpu_buffer_memset_tensor,
    .set_tensor    = ggml_backend_webgpu_buffer_set_tensor,
    .get_tensor    = ggml_backend_webgpu_buffer_get_tensor,
    .cpy_tensor    = NULL,  // TODO: optional, implement this
    .clear         = ggml_backend_webgpu_buffer_clear,
    .reset         = NULL,  // TODO: optional, think it coordinates with .init_tensor
};
|
|
1828
|
+
|
|
1829
|
+
/* End GGML Backend Buffer Interface */
|
|
1830
|
+
|
|
1831
|
+
/* GGML Backend Buffer Type Interface */
|
|
1832
|
+
|
|
1833
|
+
// The buffer type reports the name of the device it belongs to.
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(buft->device->context);
    return dev_ctx->device_name.c_str();
}
|
|
1837
|
+
|
|
1838
|
+
// Allocates a WebGPU storage buffer for `size` bytes and wraps it in a ggml backend buffer.
// The GPU allocation is rounded up to a multiple of WEBGPU_STORAGE_BUF_BINDING_MULT, while the ggml
// buffer still reports the requested `size`.
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                          size_t                     size) {
    // Process-wide counter; only used to give every allocation a distinct label.
    static std::atomic<int> buffer_count;
    const int   id    = buffer_count++;
    std::string label = "tensor_buf" + std::to_string(id);
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << id << ": " << size << " bytes");

    auto * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
    auto & gpu_ctx = dev_ctx->webgpu_ctx;

    const size_t padded_size = ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT);
    wgpu::Buffer gpu_buf;
    ggml_webgpu_create_buffer(gpu_ctx->device, gpu_buf, padded_size,
                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
                              label.c_str());

    auto * buf_ctx = new ggml_backend_webgpu_buffer_context(gpu_ctx, gpu_buf, label);
    return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
}
|
|
1856
|
+
|
|
1857
|
+
// Buffer alignment equals the device's minStorageBufferOffsetAlignment limit.
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(buft->device->context);
    const auto & limits  = dev_ctx->webgpu_ctx->limits;
    return limits.minStorageBufferOffsetAlignment;
}
|
|
1861
|
+
|
|
1862
|
+
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
|
|
1863
|
+
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
1864
|
+
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
|
1865
|
+
return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
|
|
1866
|
+
}
|
|
1867
|
+
|
|
1868
|
+
/* End GGML Backend Buffer Type Interface */
|
|
1869
|
+
|
|
1870
|
+
/* GGML Backend Device Interface */
|
|
1871
|
+
|
|
1872
|
+
// Returns the device name stored in the device context.
static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(dev->context);
    return dev_ctx->device_name.c_str();
}
|
|
1876
|
+
|
|
1877
|
+
// Returns the human-readable device description stored in the device context.
static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(dev->context);
    return dev_ctx->device_desc.c_str();
}
|
|
1881
|
+
|
|
1882
|
+
// Reports device memory. For now both `free` and `total` are the device's maxBufferSize limit,
// clamped to UINTPTR_MAX on targets whose pointers are narrower than 64 bits.
// TODO: track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    auto * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);

    uint64_t reported = dev_ctx->webgpu_ctx->limits.maxBufferSize;
#if UINTPTR_MAX < UINT64_MAX
    const uint64_t ptr_limit = static_cast<uint64_t>(UINTPTR_MAX);
    reported                 = (reported > ptr_limit) ? ptr_limit : reported;
#endif

    const size_t bytes = static_cast<size_t>(reported);
    *free              = bytes;
    *total             = bytes;
}
|
|
1897
|
+
|
|
1898
|
+
// Every WebGPU device is reported as a GPU; the device handle itself is not consulted.
static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
    (void) dev;
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}
|
|
1902
|
+
|
|
1903
|
+
// Fills `props` from the individual device getters above.
static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_webgpu_device_get_name(dev);
    props->description = ggml_backend_webgpu_device_get_description(dev);
    props->type        = ggml_backend_webgpu_device_get_type(dev);
    ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);

    // No optional capabilities are advertised: async, host_buffer, buffer_from_host_ptr and events
    // are all false (value-initialization zeroes every flag).
    props->caps = {};
}
|
|
1915
|
+
|
|
1916
|
+
// Returns the 16-byte GUID identifying the WebGPU backend (the bytes spell "__ggml_webgpu :)").
// The bytes live in a mutable static ggml_guid. The previous version reinterpreted a const string
// literal as a non-const ggml_guid_t, which cast away const from read-only storage.
static ggml_guid_t ggml_backend_webgpu_guid(void) {
    static ggml_guid guid = { '_', '_', 'g', 'g', 'm', 'l', '_', 'w', 'e', 'b', 'g', 'p', 'u', ' ', ':', ')' };
    return &guid;
}
|
|
1920
|
+
|
|
1921
|
+
// Workgroup size is a common constant
|
|
1922
|
+
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
|
|
1923
|
+
std::vector<wgpu::ConstantEntry> constants(1);
|
|
1924
|
+
constants[0].key = "wg_size";
|
|
1925
|
+
constants[0].value = wg_size;
|
|
1926
|
+
return constants;
|
|
1927
|
+
}
|
|
1928
|
+
|
|
1929
|
+
// Creates the memset pipeline and derives memset_bytes_per_thread.
// The pipeline runs at WEBGPU_MAX_WG_SIZE. bytes_per_thread is chosen so that one dispatch of at most
// maxComputeWorkgroupsPerDimension workgroups covers maxStorageBufferBindingSize bytes.
static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
    // WebGPU defines maxComputeWorkgroupsPerDimension as a 32-bit unsigned limit. Widen both operands
    // before multiplying so the thread count cannot wrap in 32-bit arithmetic on devices that report
    // a large limit.
    const uint64_t max_threads = static_cast<uint64_t>(WEBGPU_MAX_WG_SIZE) *
                                 static_cast<uint64_t>(webgpu_ctx->limits.maxComputeWorkgroupsPerDimension);
    // Size bytes_per_thread so that the largest bindable buffer can be handled by one dispatch.
    webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);

    std::vector<wgpu::ConstantEntry> constants(2);
    constants[0].key   = "wg_size";
    constants[0].value = WEBGPU_MAX_WG_SIZE;
    constants[1].key   = "bytes_per_thread";
    constants[1].value = webgpu_ctx->memset_bytes_per_thread;
    webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
}
|
|
1941
|
+
|
|
1942
|
+
// Creates all matrix-multiplication pipelines:
//  - mul_mat_pipelines[type][F32][0] for the quantized types below, from their per-type shaders;
//  - mul_mat_pipelines[src0][src1][0] and [1] (scalar / "_vec" variant) for F32xF32, F16xF32, F16xF16
//    and Q4_0xF32. These come from the subgroup-matrix shaders when the device supports them
//    (non-Emscripten builds only), otherwise from the register-tile shaders. Tile sizes are
//    substituted into the WGSL source;
//  - mul_mat_vec_pipelines[...] for the matrix-vector path.
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
    // Classic quantizations. Q4_0 x F32 is deliberately not created here: its slot [0] is always
    // assigned below from the subgroup-matrix / register-tile shader. Creating the generic Q4_0
    // pipeline here only compiled a shader whose pipeline was immediately overwritten.
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");

    // K-quantizations
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");

    // IQ quantizations (2- and 3-bit variants)
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");

    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");

    // 1-bit and 4-bit IQ variants
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");

    // Shader sources for the tiled paths, after placeholder substitution.
    std::string proc_mul_mat_f32_f32;
    std::string proc_mul_mat_f32_f32_vec;
    std::string proc_mul_mat_f16_f32;
    std::string proc_mul_mat_f16_f32_vec;
    std::string proc_mul_mat_f16_f16;
    std::string proc_mul_mat_f16_f16_vec;
    std::string proc_mul_mat_q4_0_f32;
    std::string proc_mul_mat_q4_0_f32_vec;

    // Override constants; only the register-tile path sets any.
    std::vector<wgpu::ConstantEntry> mul_mat_constants;
#ifndef __EMSCRIPTEN__
    if (webgpu_ctx->supports_subgroup_matrix) {
        std::map<std::string, std::string> sg_matrix_repls;
        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_m);
        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_n);
        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_k);

        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
        proc_mul_mat_f32_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
        proc_mul_mat_f16_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
        proc_mul_mat_f16_f16_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
        proc_mul_mat_q4_0_f32 =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
        proc_mul_mat_q4_0_f32_vec =
            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
    } else {
#endif
        mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });

        std::map<std::string, std::string> reg_repls;
        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);

        proc_mul_mat_f32_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
        proc_mul_mat_f32_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
        proc_mul_mat_f16_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
        proc_mul_mat_f16_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
        proc_mul_mat_f16_f16      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
        proc_mul_mat_f16_f16_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
        proc_mul_mat_q4_0_f32     = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
        proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
#ifndef __EMSCRIPTEN__
    }
#endif

    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);

    // Matrix-vector pipelines share one set of override constants.
    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
    mul_mat_vec_constants[1].key   = "TILE_K";
    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;

    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
        webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
}
|
|
2088
|
+
|
|
2089
|
+
// Creates the two set_rows pipelines: slot 0 uses the scalar f16 shader, slot 1 the "_vec" shader.
// Both run at WEBGPU_MAX_WG_SIZE.
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto & slots = webgpu_ctx->set_rows_pipelines[0];

    slots[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", wg);
    slots[1] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", wg);
}
|
|
2095
|
+
|
|
2096
|
+
// Creates get_rows pipelines for every supported source type, all at WEBGPU_MAX_WG_SIZE.
// Slot [type][0] holds the scalar shader. Only F32 also gets a vectorized shader, in slot [F32][1].
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", wg);
    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", wg);

    // Remaining types only have a scalar variant; table order matches the original creation order.
    struct get_rows_variant {
        ggml_type    type;
        const char * shader;
        const char * label;
    };
    const get_rows_variant variants[] = {
        { GGML_TYPE_F16,     wgsl_get_rows_f16,     "get_rows_f16"     },
        { GGML_TYPE_I32,     wgsl_get_rows_i32,     "get_rows_i32"     },
        { GGML_TYPE_Q4_0,    wgsl_get_rows_q4_0,    "get_rows_q4_0"    },
        { GGML_TYPE_Q4_1,    wgsl_get_rows_q4_1,    "get_rows_q4_1"    },
        { GGML_TYPE_Q5_0,    wgsl_get_rows_q5_0,    "get_rows_q5_0"    },
        { GGML_TYPE_Q5_1,    wgsl_get_rows_q5_1,    "get_rows_q5_1"    },
        { GGML_TYPE_Q8_0,    wgsl_get_rows_q8_0,    "get_rows_q8_0"    },
        { GGML_TYPE_Q2_K,    wgsl_get_rows_q2_k,    "get_rows_q2_k"    },
        { GGML_TYPE_Q3_K,    wgsl_get_rows_q3_k,    "get_rows_q3_k"    },
        { GGML_TYPE_Q4_K,    wgsl_get_rows_q4_k,    "get_rows_q4_k"    },
        { GGML_TYPE_Q5_K,    wgsl_get_rows_q5_k,    "get_rows_q5_k"    },
        { GGML_TYPE_Q6_K,    wgsl_get_rows_q6_k,    "get_rows_q6_k"    },
        { GGML_TYPE_IQ2_XXS, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs" },
        { GGML_TYPE_IQ2_XS,  wgsl_get_rows_iq2_xs,  "get_rows_iq2_xs"  },
        { GGML_TYPE_IQ2_S,   wgsl_get_rows_iq2_s,   "get_rows_iq2_s"   },
        { GGML_TYPE_IQ3_XXS, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs" },
        { GGML_TYPE_IQ3_S,   wgsl_get_rows_iq3_s,   "get_rows_iq3_s"   },
        { GGML_TYPE_IQ1_S,   wgsl_get_rows_iq1_s,   "get_rows_iq1_s"   },
        { GGML_TYPE_IQ1_M,   wgsl_get_rows_iq1_m,   "get_rows_iq1_m"   },
        { GGML_TYPE_IQ4_NL,  wgsl_get_rows_iq4_nl,  "get_rows_iq4_nl"  },
        { GGML_TYPE_IQ4_XS,  wgsl_get_rows_iq4_xs,  "get_rows_iq4_xs"  },
    };

    for (const get_rows_variant & v : variants) {
        webgpu_ctx->get_rows_pipelines[v.type][0] =
            ggml_webgpu_create_pipeline(webgpu_ctx->device, v.shader, v.label, wg);
    }
}
|
|
2149
|
+
|
|
2150
|
+
// Creates one copy pipeline per F32/F16 pair, all at WEBGPU_MAX_WG_SIZE.
// cpy_pipelines is indexed [first type][second type], matching the order in each shader's name.
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    struct cpy_variant {
        ggml_type    from;
        ggml_type    to;
        const char * shader;
        const char * label;
    };
    const cpy_variant variants[] = {
        { GGML_TYPE_F32, GGML_TYPE_F32, wgsl_cpy_f32_f32, "cpy_f32_f32" },
        { GGML_TYPE_F32, GGML_TYPE_F16, wgsl_cpy_f32_f16, "cpy_f32_f16" },
        { GGML_TYPE_F16, GGML_TYPE_F32, wgsl_cpy_f16_f32, "cpy_f16_f32" },
        { GGML_TYPE_F16, GGML_TYPE_F16, wgsl_cpy_f16_f16, "cpy_f16_f16" },
    };

    for (const cpy_variant & v : variants) {
        webgpu_ctx->cpy_pipelines[v.from][v.to] = ggml_webgpu_create_pipeline(webgpu_ctx->device, v.shader, v.label, wg);
    }
}
|
|
2162
|
+
|
|
2163
|
+
// Creates the element-wise add pipelines: add_pipelines[type][0] writes to a separate destination,
// add_pipelines[type][1] is the in-place variant. All run at WEBGPU_MAX_WG_SIZE.
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto build = [&](const char * shader, const char * label) {
        return ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg);
    };
    auto & slots = webgpu_ctx->add_pipelines;

    slots[GGML_TYPE_F32][0] = build(wgsl_add_f32, "add_f32");
    slots[GGML_TYPE_F16][0] = build(wgsl_add_f16, "add_f16");
    slots[GGML_TYPE_F32][1] = build(wgsl_add_f32_inplace, "add_f32_inplace");
    slots[GGML_TYPE_F16][1] = build(wgsl_add_f16_inplace, "add_f16_inplace");
}
|
|
2175
|
+
|
|
2176
|
+
// Creates the element-wise subtract pipelines: sub_pipelines[type][0] writes to a separate destination,
// sub_pipelines[type][1] is the in-place variant. All run at WEBGPU_MAX_WG_SIZE.
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto build = [&](const char * shader, const char * label) {
        return ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg);
    };
    auto & slots = webgpu_ctx->sub_pipelines;

    slots[GGML_TYPE_F32][0] = build(wgsl_sub_f32, "sub_f32");
    slots[GGML_TYPE_F16][0] = build(wgsl_sub_f16, "sub_f16");
    slots[GGML_TYPE_F32][1] = build(wgsl_sub_f32_inplace, "sub_f32_inplace");
    slots[GGML_TYPE_F16][1] = build(wgsl_sub_f16_inplace, "sub_f16_inplace");
}
|
|
2188
|
+
|
|
2189
|
+
// Creates the element-wise multiply pipelines: mul_pipelines[type][0] writes to a separate destination,
// mul_pipelines[type][1] is the in-place variant. All run at WEBGPU_MAX_WG_SIZE.
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto build = [&](const char * shader, const char * label) {
        return ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg);
    };
    auto & slots = webgpu_ctx->mul_pipelines;

    slots[GGML_TYPE_F32][0] = build(wgsl_mul_f32, "mul_f32");
    slots[GGML_TYPE_F16][0] = build(wgsl_mul_f16, "mul_f16");
    slots[GGML_TYPE_F32][1] = build(wgsl_mul_f32_inplace, "mul_f32_inplace");
    slots[GGML_TYPE_F16][1] = build(wgsl_mul_f16_inplace, "mul_f16_inplace");
}
|
|
2201
|
+
|
|
2202
|
+
// Creates the element-wise divide pipelines: div_pipelines[type][0] writes to a separate destination,
// div_pipelines[type][1] is the in-place variant. All run at WEBGPU_MAX_WG_SIZE.
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto build = [&](const char * shader, const char * label) {
        return ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg);
    };
    auto & slots = webgpu_ctx->div_pipelines;

    slots[GGML_TYPE_F32][0] = build(wgsl_div_f32, "div_f32");
    slots[GGML_TYPE_F16][0] = build(wgsl_div_f16, "div_f16");
    slots[GGML_TYPE_F32][1] = build(wgsl_div_f32_inplace, "div_f32_inplace");
    slots[GGML_TYPE_F16][1] = build(wgsl_div_f16_inplace, "div_f16_inplace");
}
|
|
2214
|
+
|
|
2215
|
+
// Creates the RMS-norm pipelines at WEBGPU_ROW_SPLIT_WG_SIZE: slot 0 writes to a separate destination,
// slot 1 is the in-place variant.
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
    auto & slots = webgpu_ctx->rms_norm_pipelines;

    slots[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", wg);
    slots[1] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", wg);
}
|
|
2223
|
+
|
|
2224
|
+
// Creates the RoPE pipelines, all at WEBGPU_MAX_WG_SIZE. rope_pipelines is indexed [type][ff][inplace]:
// ff = 1 selects the "_ff" shader variant (presumably the frequency-factor version; confirm against
// the WGSL), and inplace = 1 selects the in-place variant.
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
    const std::vector<wgpu::ConstantEntry> wg = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
    auto build = [&](const char * shader, const char * label) {
        return ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg);
    };

    auto & f32 = webgpu_ctx->rope_pipelines[GGML_TYPE_F32];
    f32[0][0]  = build(wgsl_rope_f32, "rope_f32");
    f32[0][1]  = build(wgsl_rope_f32_inplace, "rope_f32_inplace");
    f32[1][0]  = build(wgsl_rope_f32_ff, "rope_f32_ff");
    f32[1][1]  = build(wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace");

    auto & f16 = webgpu_ctx->rope_pipelines[GGML_TYPE_F16];
    f16[0][0]  = build(wgsl_rope_f16, "rope_f16");
    f16[0][1]  = build(wgsl_rope_f16_inplace, "rope_f16_inplace");
    f16[1][0]  = build(wgsl_rope_f16_ff, "rope_f16_ff");
    f16[1][1]  = build(wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace");
}
|
|
2245
|
+
|
|
2246
|
+
// Builds every GLU compute pipeline and stores it in webgpu_ctx->glu_pipelines,
// indexed as [glu_op][element type][variant], where variant 0 is the plain shader
// and variant 1 is the `_split` shader. SWIGLU_OAI only has F32 variants.
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> wg_constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    // Compiles `shader` (labelled `label`) and places it in glu_pipelines[glu_op][type][split].
    auto store = [&](ggml_glu_op glu_op, ggml_type type, int split, const auto & shader, const char * label) {
        webgpu_ctx->glu_pipelines[glu_op][type][split] =
            ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg_constants);
    };

    // REGLU
    store(GGML_GLU_OP_REGLU, GGML_TYPE_F32, 0, wgsl_reglu_f32, "reglu_f32");
    store(GGML_GLU_OP_REGLU, GGML_TYPE_F16, 0, wgsl_reglu_f16, "reglu_f16");
    store(GGML_GLU_OP_REGLU, GGML_TYPE_F32, 1, wgsl_reglu_f32_split, "reglu_f32_split");
    store(GGML_GLU_OP_REGLU, GGML_TYPE_F16, 1, wgsl_reglu_f16_split, "reglu_f16_split");

    // GEGLU
    store(GGML_GLU_OP_GEGLU, GGML_TYPE_F32, 0, wgsl_geglu_f32, "geglu_f32");
    store(GGML_GLU_OP_GEGLU, GGML_TYPE_F16, 0, wgsl_geglu_f16, "geglu_f16");
    store(GGML_GLU_OP_GEGLU, GGML_TYPE_F32, 1, wgsl_geglu_f32_split, "geglu_f32_split");
    store(GGML_GLU_OP_GEGLU, GGML_TYPE_F16, 1, wgsl_geglu_f16_split, "geglu_f16_split");

    // SWIGLU
    store(GGML_GLU_OP_SWIGLU, GGML_TYPE_F32, 0, wgsl_swiglu_f32, "swiglu_f32");
    store(GGML_GLU_OP_SWIGLU, GGML_TYPE_F16, 0, wgsl_swiglu_f16, "swiglu_f16");
    store(GGML_GLU_OP_SWIGLU, GGML_TYPE_F32, 1, wgsl_swiglu_f32_split, "swiglu_f32_split");
    store(GGML_GLU_OP_SWIGLU, GGML_TYPE_F16, 1, wgsl_swiglu_f16_split, "swiglu_f16_split");

    // SWIGLU_OAI (F32 only)
    store(GGML_GLU_OP_SWIGLU_OAI, GGML_TYPE_F32, 0, wgsl_swiglu_oai_f32, "swiglu_oai_f32");
    store(GGML_GLU_OP_SWIGLU_OAI, GGML_TYPE_F32, 1, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split");

    // GEGLU_ERF
    store(GGML_GLU_OP_GEGLU_ERF, GGML_TYPE_F32, 0, wgsl_geglu_erf_f32, "geglu_erf_f32");
    store(GGML_GLU_OP_GEGLU_ERF, GGML_TYPE_F16, 0, wgsl_geglu_erf_f16, "geglu_erf_f16");
    store(GGML_GLU_OP_GEGLU_ERF, GGML_TYPE_F32, 1, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split");
    store(GGML_GLU_OP_GEGLU_ERF, GGML_TYPE_F16, 1, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split");

    // GEGLU_QUICK
    store(GGML_GLU_OP_GEGLU_QUICK, GGML_TYPE_F32, 0, wgsl_geglu_quick_f32, "geglu_quick_f32");
    store(GGML_GLU_OP_GEGLU_QUICK, GGML_TYPE_F16, 0, wgsl_geglu_quick_f16, "geglu_quick_f16");
    store(GGML_GLU_OP_GEGLU_QUICK, GGML_TYPE_F32, 1, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split");
    store(GGML_GLU_OP_GEGLU_QUICK, GGML_TYPE_F16, 1, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split");
}
|
|
2305
|
+
|
|
2306
|
+
// Builds every unary-op compute pipeline and stores it in webgpu_ctx->unary_pipelines,
// indexed as [unary_op][element type][variant], where variant 0 is the out-of-place
// shader and variant 1 is the `_inplace` shader. Each op gets F32 and F16 versions.
static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> wg_constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    // Compiles `shader` (labelled `label`) and places it in unary_pipelines[unary_op][type][inplace].
    auto store = [&](ggml_unary_op unary_op, ggml_type type, int inplace, const auto & shader, const char * label) {
        webgpu_ctx->unary_pipelines[unary_op][type][inplace] =
            ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg_constants);
    };

    // ABS
    store(GGML_UNARY_OP_ABS, GGML_TYPE_F32, 0, wgsl_abs_f32, "abs_f32");
    store(GGML_UNARY_OP_ABS, GGML_TYPE_F16, 0, wgsl_abs_f16, "abs_f16");
    store(GGML_UNARY_OP_ABS, GGML_TYPE_F32, 1, wgsl_abs_inplace_f32, "abs_inplace_f32");
    store(GGML_UNARY_OP_ABS, GGML_TYPE_F16, 1, wgsl_abs_inplace_f16, "abs_inplace_f16");

    // SGN
    store(GGML_UNARY_OP_SGN, GGML_TYPE_F32, 0, wgsl_sgn_f32, "sgn_f32");
    store(GGML_UNARY_OP_SGN, GGML_TYPE_F16, 0, wgsl_sgn_f16, "sgn_f16");
    store(GGML_UNARY_OP_SGN, GGML_TYPE_F32, 1, wgsl_sgn_inplace_f32, "sgn_inplace_f32");
    store(GGML_UNARY_OP_SGN, GGML_TYPE_F16, 1, wgsl_sgn_inplace_f16, "sgn_inplace_f16");

    // NEG
    store(GGML_UNARY_OP_NEG, GGML_TYPE_F32, 0, wgsl_neg_f32, "neg_f32");
    store(GGML_UNARY_OP_NEG, GGML_TYPE_F16, 0, wgsl_neg_f16, "neg_f16");
    store(GGML_UNARY_OP_NEG, GGML_TYPE_F32, 1, wgsl_neg_inplace_f32, "neg_inplace_f32");
    store(GGML_UNARY_OP_NEG, GGML_TYPE_F16, 1, wgsl_neg_inplace_f16, "neg_inplace_f16");

    // STEP
    store(GGML_UNARY_OP_STEP, GGML_TYPE_F32, 0, wgsl_step_f32, "step_f32");
    store(GGML_UNARY_OP_STEP, GGML_TYPE_F16, 0, wgsl_step_f16, "step_f16");
    store(GGML_UNARY_OP_STEP, GGML_TYPE_F32, 1, wgsl_step_inplace_f32, "step_inplace_f32");
    store(GGML_UNARY_OP_STEP, GGML_TYPE_F16, 1, wgsl_step_inplace_f16, "step_inplace_f16");

    // TANH
    store(GGML_UNARY_OP_TANH, GGML_TYPE_F32, 0, wgsl_tanh_f32, "tanh_f32");
    store(GGML_UNARY_OP_TANH, GGML_TYPE_F16, 0, wgsl_tanh_f16, "tanh_f16");
    store(GGML_UNARY_OP_TANH, GGML_TYPE_F32, 1, wgsl_tanh_inplace_f32, "tanh_inplace_f32");
    store(GGML_UNARY_OP_TANH, GGML_TYPE_F16, 1, wgsl_tanh_inplace_f16, "tanh_inplace_f16");

    // ELU
    store(GGML_UNARY_OP_ELU, GGML_TYPE_F32, 0, wgsl_elu_f32, "elu_f32");
    store(GGML_UNARY_OP_ELU, GGML_TYPE_F16, 0, wgsl_elu_f16, "elu_f16");
    store(GGML_UNARY_OP_ELU, GGML_TYPE_F32, 1, wgsl_elu_inplace_f32, "elu_inplace_f32");
    store(GGML_UNARY_OP_ELU, GGML_TYPE_F16, 1, wgsl_elu_inplace_f16, "elu_inplace_f16");

    // RELU
    store(GGML_UNARY_OP_RELU, GGML_TYPE_F32, 0, wgsl_relu_f32, "relu_f32");
    store(GGML_UNARY_OP_RELU, GGML_TYPE_F16, 0, wgsl_relu_f16, "relu_f16");
    store(GGML_UNARY_OP_RELU, GGML_TYPE_F32, 1, wgsl_relu_inplace_f32, "relu_inplace_f32");
    store(GGML_UNARY_OP_RELU, GGML_TYPE_F16, 1, wgsl_relu_inplace_f16, "relu_inplace_f16");

    // SIGMOID
    store(GGML_UNARY_OP_SIGMOID, GGML_TYPE_F32, 0, wgsl_sigmoid_f32, "sigmoid_f32");
    store(GGML_UNARY_OP_SIGMOID, GGML_TYPE_F16, 0, wgsl_sigmoid_f16, "sigmoid_f16");
    store(GGML_UNARY_OP_SIGMOID, GGML_TYPE_F32, 1, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32");
    store(GGML_UNARY_OP_SIGMOID, GGML_TYPE_F16, 1, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16");

    // GELU
    store(GGML_UNARY_OP_GELU, GGML_TYPE_F32, 0, wgsl_gelu_f32, "gelu_f32");
    store(GGML_UNARY_OP_GELU, GGML_TYPE_F16, 0, wgsl_gelu_f16, "gelu_f16");
    store(GGML_UNARY_OP_GELU, GGML_TYPE_F32, 1, wgsl_gelu_inplace_f32, "gelu_inplace_f32");
    store(GGML_UNARY_OP_GELU, GGML_TYPE_F16, 1, wgsl_gelu_inplace_f16, "gelu_inplace_f16");

    // GELU_QUICK
    store(GGML_UNARY_OP_GELU_QUICK, GGML_TYPE_F32, 0, wgsl_gelu_quick_f32, "gelu_quick_f32");
    store(GGML_UNARY_OP_GELU_QUICK, GGML_TYPE_F16, 0, wgsl_gelu_quick_f16, "gelu_quick_f16");
    store(GGML_UNARY_OP_GELU_QUICK, GGML_TYPE_F32, 1, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32");
    store(GGML_UNARY_OP_GELU_QUICK, GGML_TYPE_F16, 1, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16");

    // SILU
    store(GGML_UNARY_OP_SILU, GGML_TYPE_F32, 0, wgsl_silu_f32, "silu_f32");
    store(GGML_UNARY_OP_SILU, GGML_TYPE_F16, 0, wgsl_silu_f16, "silu_f16");
    store(GGML_UNARY_OP_SILU, GGML_TYPE_F32, 1, wgsl_silu_inplace_f32, "silu_inplace_f32");
    store(GGML_UNARY_OP_SILU, GGML_TYPE_F16, 1, wgsl_silu_inplace_f16, "silu_inplace_f16");

    // HARDSWISH
    store(GGML_UNARY_OP_HARDSWISH, GGML_TYPE_F32, 0, wgsl_hardswish_f32, "hardswish_f32");
    store(GGML_UNARY_OP_HARDSWISH, GGML_TYPE_F16, 0, wgsl_hardswish_f16, "hardswish_f16");
    store(GGML_UNARY_OP_HARDSWISH, GGML_TYPE_F32, 1, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32");
    store(GGML_UNARY_OP_HARDSWISH, GGML_TYPE_F16, 1, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16");

    // HARDSIGMOID
    store(GGML_UNARY_OP_HARDSIGMOID, GGML_TYPE_F32, 0, wgsl_hardsigmoid_f32, "hardsigmoid_f32");
    store(GGML_UNARY_OP_HARDSIGMOID, GGML_TYPE_F16, 0, wgsl_hardsigmoid_f16, "hardsigmoid_f16");
    store(GGML_UNARY_OP_HARDSIGMOID, GGML_TYPE_F32, 1, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32");
    store(GGML_UNARY_OP_HARDSIGMOID, GGML_TYPE_F16, 1, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16");

    // EXP
    store(GGML_UNARY_OP_EXP, GGML_TYPE_F32, 0, wgsl_exp_f32, "exp_f32");
    store(GGML_UNARY_OP_EXP, GGML_TYPE_F16, 0, wgsl_exp_f16, "exp_f16");
    store(GGML_UNARY_OP_EXP, GGML_TYPE_F32, 1, wgsl_exp_inplace_f32, "exp_inplace_f32");
    store(GGML_UNARY_OP_EXP, GGML_TYPE_F16, 1, wgsl_exp_inplace_f16, "exp_inplace_f16");

    // GELU_ERF
    store(GGML_UNARY_OP_GELU_ERF, GGML_TYPE_F32, 0, wgsl_gelu_erf_f32, "gelu_erf_f32");
    store(GGML_UNARY_OP_GELU_ERF, GGML_TYPE_F16, 0, wgsl_gelu_erf_f16, "gelu_erf_f16");
    store(GGML_UNARY_OP_GELU_ERF, GGML_TYPE_F32, 1, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32");
    store(GGML_UNARY_OP_GELU_ERF, GGML_TYPE_F16, 1, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16");

    // XIELU
    store(GGML_UNARY_OP_XIELU, GGML_TYPE_F32, 0, wgsl_xielu_f32, "xielu_f32");
    store(GGML_UNARY_OP_XIELU, GGML_TYPE_F16, 0, wgsl_xielu_f16, "xielu_f16");
    store(GGML_UNARY_OP_XIELU, GGML_TYPE_F32, 1, wgsl_xielu_inplace_f32, "xielu_inplace_f32");
    store(GGML_UNARY_OP_XIELU, GGML_TYPE_F16, 1, wgsl_xielu_inplace_f16, "xielu_inplace_f16");

    // CEIL
    store(GGML_UNARY_OP_CEIL, GGML_TYPE_F32, 0, wgsl_ceil_f32, "ceil_f32");
    store(GGML_UNARY_OP_CEIL, GGML_TYPE_F16, 0, wgsl_ceil_f16, "ceil_f16");
    store(GGML_UNARY_OP_CEIL, GGML_TYPE_F32, 1, wgsl_ceil_inplace_f32, "ceil_inplace_f32");
    store(GGML_UNARY_OP_CEIL, GGML_TYPE_F16, 1, wgsl_ceil_inplace_f16, "ceil_inplace_f16");
}
|
|
2479
|
+
|
|
2480
|
+
// Builds the two F32 scale pipelines: slot 0 writes to a separate destination,
// slot 1 is the `_inplace` shader.
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> wg_constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);

    auto & pipelines = webgpu_ctx->scale_pipelines;
    pipelines[0]     = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", wg_constants);
    pipelines[1] =
        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", wg_constants);
}
|
|
2488
|
+
|
|
2489
|
+
// Builds all F32 soft_max pipelines, stored in webgpu_ctx->soft_max_pipelines as
// [mask_type][has_sink][inplace]:
//   mask_type: 0 = f32 mask, 1 = f16 mask, 2 = no mask
//   has_sink:  1 selects the `_sink` shader
//   inplace:   1 selects the `_inplace` shader
// Uses WEBGPU_ROW_SPLIT_WG_SIZE rather than WEBGPU_MAX_WG_SIZE for the workgroup size.
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> wg_constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);

    // Compiles `shader` (labelled `label`) into soft_max_pipelines[mask_type][has_sink][inplace].
    auto store = [&](int mask_type, int has_sink, int inplace, const auto & shader, const char * label) {
        webgpu_ctx->soft_max_pipelines[mask_type][has_sink][inplace] =
            ggml_webgpu_create_pipeline(webgpu_ctx->device, shader, label, wg_constants);
    };

    // no mask (mask_type = 2)
    store(2, 0, 0, wgsl_soft_max_f32, "soft_max_f32");
    store(2, 0, 1, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace");
    store(2, 1, 0, wgsl_soft_max_f32_sink, "soft_max_f32_sink");
    store(2, 1, 1, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace");

    // f32 mask (mask_type = 0)
    store(0, 0, 0, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32");
    store(0, 0, 1, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace");
    store(0, 1, 0, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink");
    store(0, 1, 1, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace");

    // f16 mask (mask_type = 1)
    store(1, 0, 0, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16");
    store(1, 0, 1, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace");
    store(1, 1, 0, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink");
    store(1, 1, 1, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace");
}
|
|
2522
|
+
|
|
2523
|
+
// TODO: move most initialization logic here
|
|
2524
|
+
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
|
2525
|
+
GGML_UNUSED(params);
|
|
2526
|
+
|
|
2527
|
+
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
|
|
2528
|
+
|
|
2529
|
+
ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
2530
|
+
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
|
|
2531
|
+
|
|
2532
|
+
static ggml_backend_webgpu_context backend_ctx;
|
|
2533
|
+
backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
|
|
2534
|
+
backend_ctx.webgpu_ctx = webgpu_ctx;
|
|
2535
|
+
|
|
2536
|
+
// See GGML Backend Interface section
|
|
2537
|
+
static ggml_backend backend = {
|
|
2538
|
+
/* .guid = */ ggml_backend_webgpu_guid(),
|
|
2539
|
+
/* .interface = */ ggml_backend_webgpu_i,
|
|
2540
|
+
/* .device = */ dev,
|
|
2541
|
+
/* .context = */ &backend_ctx,
|
|
2542
|
+
};
|
|
2543
|
+
return &backend;
|
|
2544
|
+
}
|
|
2545
|
+
|
|
2546
|
+
// Returns the (single, function-static) WebGPU buffer type.
// NOTE: the `device` field is captured from the `dev` of the first call only.
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
    // See GGML Backend Buffer Type Interface section
    static struct ggml_backend_buffer_type webgpu_buft = {
        /* .iface = */ {
            /* .get_name       = */ ggml_backend_webgpu_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_webgpu_buffer_type_get_alignment,
            /* .get_max_size   = */ ggml_backend_webgpu_buffer_type_get_max_size,
            /* .get_alloc_size = */ nullptr,  // null: defaults to ggml_nbytes
            /* .is_host        = */ nullptr,  // null: defaults to false
        },
        /* .device  = */ dev,
        /* .context = */ nullptr,
    };

    return &webgpu_buft;
}
|
|
2565
|
+
|
|
2566
|
+
// A buffer type belongs to this backend exactly when its get_name callback is the
// WebGPU one; the device argument is not consulted.
static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(dev);
    const bool is_webgpu_buft = (buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name);
    return is_webgpu_buft;
}
|
|
2570
|
+
|
|
2571
|
+
// Returns true for the quantized tensor types this backend has shaders for.
// Non-quantized types (F32/F16/I32, ...) return false; callers check them separately.
static bool ggml_webgpu_supported_qtype(ggml_type type) {
    bool is_supported = false;
    switch (type) {
        // legacy block quantizations
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        // k-quants
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        // i-quants
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_IQ4_XS:
            is_supported = true;
            break;
        default:
            break;
    }
    return is_supported;
}
|
|
2597
|
+
|
|
2598
|
+
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
2599
|
+
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
|
2600
|
+
|
|
2601
|
+
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
|
|
2602
|
+
|
|
2603
|
+
ggml_tensor * src0 = op->src[0];
|
|
2604
|
+
ggml_tensor * src1 = op->src[1];
|
|
2605
|
+
ggml_tensor * src2 = op->src[2];
|
|
2606
|
+
|
|
2607
|
+
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
|
|
2608
|
+
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
|
|
2609
|
+
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
|
2610
|
+
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
|
|
2611
|
+
return false;
|
|
2612
|
+
}
|
|
2613
|
+
|
|
2614
|
+
bool supports_op = false;
|
|
2615
|
+
switch (op->op) {
|
|
2616
|
+
case GGML_OP_NONE:
|
|
2617
|
+
case GGML_OP_VIEW:
|
|
2618
|
+
case GGML_OP_PERMUTE:
|
|
2619
|
+
case GGML_OP_TRANSPOSE:
|
|
2620
|
+
case GGML_OP_RESHAPE:
|
|
2621
|
+
supports_op = true;
|
|
2622
|
+
break;
|
|
2623
|
+
case GGML_OP_ADD:
|
|
2624
|
+
case GGML_OP_SUB:
|
|
2625
|
+
case GGML_OP_MUL:
|
|
2626
|
+
case GGML_OP_DIV:
|
|
2627
|
+
// TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
|
|
2628
|
+
// see https://github.com/ggml-org/llama.cpp/pull/16857
|
|
2629
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
|
2630
|
+
(src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
|
2631
|
+
break;
|
|
2632
|
+
case GGML_OP_CPY:
|
|
2633
|
+
case GGML_OP_CONT:
|
|
2634
|
+
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
2635
|
+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
2636
|
+
break;
|
|
2637
|
+
case GGML_OP_SET_ROWS:
|
|
2638
|
+
supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
|
|
2639
|
+
break;
|
|
2640
|
+
case GGML_OP_GET_ROWS:
|
|
2641
|
+
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
|
|
2642
|
+
ggml_webgpu_supported_qtype(src0->type)) {
|
|
2643
|
+
supports_op = (op->type == GGML_TYPE_F32);
|
|
2644
|
+
}
|
|
2645
|
+
break;
|
|
2646
|
+
case GGML_OP_MUL_MAT:
|
|
2647
|
+
{
|
|
2648
|
+
switch (src1->type) {
|
|
2649
|
+
case GGML_TYPE_F16:
|
|
2650
|
+
supports_op |= (src0->type == GGML_TYPE_F16);
|
|
2651
|
+
break;
|
|
2652
|
+
case GGML_TYPE_F32:
|
|
2653
|
+
switch (src0->type) {
|
|
2654
|
+
case GGML_TYPE_F32:
|
|
2655
|
+
case GGML_TYPE_F16:
|
|
2656
|
+
case GGML_TYPE_Q4_0:
|
|
2657
|
+
case GGML_TYPE_Q4_1:
|
|
2658
|
+
case GGML_TYPE_Q5_0:
|
|
2659
|
+
case GGML_TYPE_Q5_1:
|
|
2660
|
+
case GGML_TYPE_Q8_0:
|
|
2661
|
+
case GGML_TYPE_Q2_K:
|
|
2662
|
+
case GGML_TYPE_Q3_K:
|
|
2663
|
+
case GGML_TYPE_Q4_K:
|
|
2664
|
+
case GGML_TYPE_Q5_K:
|
|
2665
|
+
case GGML_TYPE_Q6_K:
|
|
2666
|
+
case GGML_TYPE_IQ2_XXS:
|
|
2667
|
+
case GGML_TYPE_IQ2_XS:
|
|
2668
|
+
case GGML_TYPE_IQ2_S:
|
|
2669
|
+
case GGML_TYPE_IQ3_XXS:
|
|
2670
|
+
case GGML_TYPE_IQ3_S:
|
|
2671
|
+
case GGML_TYPE_IQ1_S:
|
|
2672
|
+
case GGML_TYPE_IQ1_M:
|
|
2673
|
+
case GGML_TYPE_IQ4_NL:
|
|
2674
|
+
case GGML_TYPE_IQ4_XS:
|
|
2675
|
+
supports_op = true;
|
|
2676
|
+
break;
|
|
2677
|
+
default:
|
|
2678
|
+
break;
|
|
2679
|
+
}
|
|
2680
|
+
default:
|
|
2681
|
+
break;
|
|
2682
|
+
}
|
|
2683
|
+
break;
|
|
2684
|
+
}
|
|
2685
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
2686
|
+
{
|
|
2687
|
+
if (!webgpu_ctx->supports_subgroup_matrix) {
|
|
2688
|
+
break;
|
|
2689
|
+
}
|
|
2690
|
+
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
|
2691
|
+
size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
|
|
2692
|
+
const bool has_mask = op->src[3] != nullptr;
|
|
2693
|
+
const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
|
|
2694
|
+
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
|
|
2695
|
+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
|
2696
|
+
webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
|
|
2697
|
+
has_mask, kv_direct);
|
|
2698
|
+
if (min_bytes > limit_bytes) {
|
|
2699
|
+
break;
|
|
2700
|
+
}
|
|
2701
|
+
|
|
2702
|
+
supports_op = src0->type == GGML_TYPE_F32 &&
|
|
2703
|
+
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
|
2704
|
+
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
|
2705
|
+
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
|
2706
|
+
break;
|
|
2707
|
+
}
|
|
2708
|
+
case GGML_OP_RMS_NORM:
|
|
2709
|
+
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
|
2710
|
+
break;
|
|
2711
|
+
case GGML_OP_ROPE:
|
|
2712
|
+
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
|
2713
|
+
break;
|
|
2714
|
+
case GGML_OP_GLU:
|
|
2715
|
+
switch (ggml_get_glu_op(op)) {
|
|
2716
|
+
case GGML_GLU_OP_REGLU:
|
|
2717
|
+
case GGML_GLU_OP_GEGLU:
|
|
2718
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2719
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2720
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2721
|
+
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
|
2722
|
+
break;
|
|
2723
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
2724
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
2725
|
+
break;
|
|
2726
|
+
default:
|
|
2727
|
+
break;
|
|
2728
|
+
}
|
|
2729
|
+
break;
|
|
2730
|
+
case GGML_OP_SCALE:
|
|
2731
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
2732
|
+
break;
|
|
2733
|
+
case GGML_OP_SOFT_MAX:
|
|
2734
|
+
supports_op = op->type == GGML_TYPE_F32;
|
|
2735
|
+
break;
|
|
2736
|
+
case GGML_OP_UNARY:
|
|
2737
|
+
{
|
|
2738
|
+
const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
|
|
2739
|
+
|
|
2740
|
+
switch (UNARY_OP) {
|
|
2741
|
+
case GGML_UNARY_OP_ABS:
|
|
2742
|
+
case GGML_UNARY_OP_SGN:
|
|
2743
|
+
case GGML_UNARY_OP_NEG:
|
|
2744
|
+
case GGML_UNARY_OP_STEP:
|
|
2745
|
+
case GGML_UNARY_OP_TANH:
|
|
2746
|
+
case GGML_UNARY_OP_ELU:
|
|
2747
|
+
case GGML_UNARY_OP_RELU:
|
|
2748
|
+
case GGML_UNARY_OP_SIGMOID:
|
|
2749
|
+
case GGML_UNARY_OP_GELU:
|
|
2750
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
|
2751
|
+
case GGML_UNARY_OP_SILU:
|
|
2752
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
2753
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
2754
|
+
case GGML_UNARY_OP_EXP:
|
|
2755
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
2756
|
+
case GGML_UNARY_OP_XIELU:
|
|
2757
|
+
case GGML_UNARY_OP_CEIL:
|
|
2758
|
+
supports_op = supports_op =
|
|
2759
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
|
2760
|
+
break;
|
|
2761
|
+
default:
|
|
2762
|
+
break;
|
|
2763
|
+
}
|
|
2764
|
+
}
|
|
2765
|
+
break;
|
|
2766
|
+
|
|
2767
|
+
default:
|
|
2768
|
+
break;
|
|
2769
|
+
}
|
|
2770
|
+
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
|
|
2771
|
+
(src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
|
2772
|
+
(src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
|
|
2773
|
+
(src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
|
|
2774
|
+
supports_op = false;
|
|
2775
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
|
|
2776
|
+
}
|
|
2777
|
+
|
|
2778
|
+
if (!supports_op) {
|
|
2779
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
|
|
2780
|
+
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
|
|
2781
|
+
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
|
2782
|
+
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
|
2783
|
+
} else {
|
|
2784
|
+
WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
|
|
2785
|
+
<< ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
|
|
2786
|
+
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
|
2787
|
+
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
|
2788
|
+
}
|
|
2789
|
+
return supports_op;
|
|
2790
|
+
}
|
|
2791
|
+
|
|
2792
|
+
// Device interface (vtable) for the WebGPU device object handed out by
// ggml_backend_webgpu_reg_get_device(), which copies it into the device's .iface slot.
// Initialization is positional: entries must stay in the field order of
// ggml_backend_device_i, and the /* .field = */ comments name each slot.
// NULL slots are hooks this backend does not implement: host buffer type,
// buffer-from-host-pointer, offload_op, and the event API.
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
    /* .get_name             = */ ggml_backend_webgpu_device_get_name,
    /* .get_description      = */ ggml_backend_webgpu_device_get_description,
    /* .get_memory           = */ ggml_backend_webgpu_device_get_memory,
    /* .get_type             = */ ggml_backend_webgpu_device_get_type,
    /* .get_props            = */ ggml_backend_webgpu_device_get_props,
    /* .init_backend         = */ ggml_backend_webgpu_device_init,
    /* .get_buffer_type      = */ ggml_backend_webgpu_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_webgpu_device_supports_op,
    /* .supports_buft        = */ ggml_backend_webgpu_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
|
|
2809
|
+
|
|
2810
|
+
/* End GGML Backend Device Interface */
|
|
2811
|
+
|
|
2812
|
+
/* GGML Backend Registration Interface */
|
|
2813
|
+
|
|
2814
|
+
// Registry interface: returns the registry's name, which is stored in the
// ggml_backend_webgpu_reg_context that reg->context points to.
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_webgpu_reg_context *>(reg->context);
    return reg_ctx->name;
}
|
|
2818
|
+
|
|
2819
|
+
// Registry interface: returns how many devices this registry exposes, as recorded
// in the ggml_backend_webgpu_reg_context that reg->context points to.
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_webgpu_reg_context *>(reg->context);
    return reg_ctx->device_count;
}
|
|
2823
|
+
|
|
2824
|
+
// TODO: Does this need to be thread safe? Is it only called once?
|
|
2825
|
+
// TODO: move most logic to device_init function so backend can be freed/initialized properly
|
|
2826
|
+
// Only one device is supported for now
|
|
2827
|
+
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
|
2828
|
+
GGML_ASSERT(index == 0);
|
|
2829
|
+
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
|
|
2830
|
+
|
|
2831
|
+
WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
|
|
2832
|
+
|
|
2833
|
+
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
|
2834
|
+
|
|
2835
|
+
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
|
2836
|
+
|
|
2837
|
+
wgpu::RequestAdapterOptions options = {};
|
|
2838
|
+
|
|
2839
|
+
#ifndef __EMSCRIPTEN__
|
|
2840
|
+
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
|
|
2841
|
+
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
|
|
2842
|
+
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
|
|
2843
|
+
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
|
|
2844
|
+
adapterTogglesDesc.enabledToggleCount = 2;
|
|
2845
|
+
options.nextInChain = &adapterTogglesDesc;
|
|
2846
|
+
#endif
|
|
2847
|
+
|
|
2848
|
+
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
|
|
2849
|
+
&options, wgpu::CallbackMode::AllowSpontaneous,
|
|
2850
|
+
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
|
2851
|
+
if (status != wgpu::RequestAdapterStatus::Success) {
|
|
2852
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
|
2853
|
+
return;
|
|
2854
|
+
}
|
|
2855
|
+
ctx->adapter = std::move(adapter);
|
|
2856
|
+
}),
|
|
2857
|
+
UINT64_MAX);
|
|
2858
|
+
GGML_ASSERT(ctx->adapter != nullptr);
|
|
2859
|
+
|
|
2860
|
+
ctx->adapter.GetLimits(&ctx->limits);
|
|
2861
|
+
|
|
2862
|
+
wgpu::AdapterInfo info{};
|
|
2863
|
+
#ifndef __EMSCRIPTEN__
|
|
2864
|
+
wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
|
|
2865
|
+
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
|
2866
|
+
info.nextInChain = &subgroup_matrix_configs;
|
|
2867
|
+
}
|
|
2868
|
+
#endif
|
|
2869
|
+
ctx->adapter.GetInfo(&info);
|
|
2870
|
+
|
|
2871
|
+
wgpu::SupportedFeatures features;
|
|
2872
|
+
ctx->adapter.GetFeatures(&features);
|
|
2873
|
+
// we require f16 support
|
|
2874
|
+
GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
|
2875
|
+
|
|
2876
|
+
#ifndef __EMSCRIPTEN__
|
|
2877
|
+
// Only support square f16 matrices of size 8 or 16 for now
|
|
2878
|
+
bool valid_subgroup_matrix_config = false;
|
|
2879
|
+
if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
|
2880
|
+
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
|
2881
|
+
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
|
2882
|
+
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
|
2883
|
+
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
|
2884
|
+
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
|
2885
|
+
ctx->sg_mat_m = config.M;
|
|
2886
|
+
ctx->sg_mat_n = config.N;
|
|
2887
|
+
ctx->sg_mat_k = config.K;
|
|
2888
|
+
valid_subgroup_matrix_config = true;
|
|
2889
|
+
break;
|
|
2890
|
+
}
|
|
2891
|
+
}
|
|
2892
|
+
}
|
|
2893
|
+
|
|
2894
|
+
ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
|
|
2895
|
+
#endif
|
|
2896
|
+
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
|
|
2897
|
+
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
|
|
2898
|
+
ctx->max_subgroup_size = info.subgroupMaxSize;
|
|
2899
|
+
|
|
2900
|
+
// Initialize device
|
|
2901
|
+
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
|
|
2902
|
+
|
|
2903
|
+
#ifndef __EMSCRIPTEN__
|
|
2904
|
+
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
|
|
2905
|
+
if (ctx->supports_subgroup_matrix) {
|
|
2906
|
+
required_features.push_back(wgpu::FeatureName::Subgroups);
|
|
2907
|
+
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
|
2908
|
+
}
|
|
2909
|
+
#endif
|
|
2910
|
+
|
|
2911
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
2912
|
+
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
|
2913
|
+
#endif
|
|
2914
|
+
|
|
2915
|
+
wgpu::DeviceDescriptor dev_desc;
|
|
2916
|
+
dev_desc.requiredLimits = &ctx->limits;
|
|
2917
|
+
dev_desc.requiredFeatures = required_features.data();
|
|
2918
|
+
dev_desc.requiredFeatureCount = required_features.size();
|
|
2919
|
+
dev_desc.SetDeviceLostCallback(
|
|
2920
|
+
wgpu::CallbackMode::AllowSpontaneous,
|
|
2921
|
+
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
|
|
2922
|
+
GGML_UNUSED(device);
|
|
2923
|
+
GGML_UNUSED(reason);
|
|
2924
|
+
GGML_UNUSED(message);
|
|
2925
|
+
//TODO: uncomment once proper free logic is in place
|
|
2926
|
+
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
2927
|
+
//std::string(message).c_str());
|
|
2928
|
+
});
|
|
2929
|
+
dev_desc.SetUncapturedErrorCallback(
|
|
2930
|
+
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
|
|
2931
|
+
GGML_UNUSED(device);
|
|
2932
|
+
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
|
|
2933
|
+
std::string(message).c_str());
|
|
2934
|
+
});
|
|
2935
|
+
|
|
2936
|
+
#ifndef __EMSCRIPTEN__
|
|
2937
|
+
// Enable Dawn-specific toggles to increase native performance
|
|
2938
|
+
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
|
|
2939
|
+
// only for native performance?
|
|
2940
|
+
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
|
|
2941
|
+
"disable_polyfills_on_integer_div_and_mod" };
|
|
2942
|
+
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
|
|
2943
|
+
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
|
|
2944
|
+
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
|
|
2945
|
+
deviceTogglesDesc.enabledToggleCount = 4;
|
|
2946
|
+
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
|
|
2947
|
+
deviceTogglesDesc.disabledToggleCount = 1;
|
|
2948
|
+
|
|
2949
|
+
dev_desc.nextInChain = &deviceTogglesDesc;
|
|
2950
|
+
#endif
|
|
2951
|
+
|
|
2952
|
+
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
|
|
2953
|
+
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
|
|
2954
|
+
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
|
|
2955
|
+
if (status != wgpu::RequestDeviceStatus::Success) {
|
|
2956
|
+
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
|
|
2957
|
+
std::string(message).c_str());
|
|
2958
|
+
return;
|
|
2959
|
+
}
|
|
2960
|
+
ctx->device = std::move(device);
|
|
2961
|
+
}),
|
|
2962
|
+
UINT64_MAX);
|
|
2963
|
+
GGML_ASSERT(ctx->device != nullptr);
|
|
2964
|
+
|
|
2965
|
+
// Initialize (compute) queue
|
|
2966
|
+
ctx->queue = ctx->device.GetQueue();
|
|
2967
|
+
|
|
2968
|
+
// Create buffer pool for shader parameters
|
|
2969
|
+
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
|
2970
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
|
2971
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
|
2972
|
+
|
|
2973
|
+
#ifdef GGML_WEBGPU_GPU_PROFILE
|
|
2974
|
+
// Initialize buffer pool for timestamp queries (profiling)
|
|
2975
|
+
ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
|
|
2976
|
+
WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
|
|
2977
|
+
wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
|
|
2978
|
+
wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
|
|
2979
|
+
#endif
|
|
2980
|
+
|
|
2981
|
+
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
|
2982
|
+
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
|
2983
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
|
2984
|
+
|
|
2985
|
+
ggml_webgpu_init_memset_pipeline(ctx);
|
|
2986
|
+
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
|
2987
|
+
ggml_webgpu_init_set_rows_pipeline(ctx);
|
|
2988
|
+
ggml_webgpu_init_get_rows_pipeline(ctx);
|
|
2989
|
+
ggml_webgpu_init_cpy_pipeline(ctx);
|
|
2990
|
+
ggml_webgpu_init_add_pipeline(ctx);
|
|
2991
|
+
ggml_webgpu_init_sub_pipeline(ctx);
|
|
2992
|
+
ggml_webgpu_init_mul_pipeline(ctx);
|
|
2993
|
+
ggml_webgpu_init_div_pipeline(ctx);
|
|
2994
|
+
ggml_webgpu_init_rms_norm_pipeline(ctx);
|
|
2995
|
+
ggml_webgpu_init_rope_pipeline(ctx);
|
|
2996
|
+
ggml_webgpu_init_glu_pipeline(ctx);
|
|
2997
|
+
ggml_webgpu_init_scale_pipeline(ctx);
|
|
2998
|
+
ggml_webgpu_init_soft_max_pipeline(ctx);
|
|
2999
|
+
ggml_webgpu_init_unary_pipeline(ctx);
|
|
3000
|
+
|
|
3001
|
+
#ifdef GGML_WEBGPU_DEBUG
|
|
3002
|
+
// Initialize debug buffers
|
|
3003
|
+
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
3004
|
+
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
|
|
3005
|
+
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
|
|
3006
|
+
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
|
|
3007
|
+
#endif
|
|
3008
|
+
|
|
3009
|
+
static ggml_backend_webgpu_device_context device_ctx;
|
|
3010
|
+
device_ctx.webgpu_ctx = ctx;
|
|
3011
|
+
device_ctx.device_name = GGML_WEBGPU_NAME;
|
|
3012
|
+
device_ctx.device_desc = info.description;
|
|
3013
|
+
|
|
3014
|
+
GGML_LOG_INFO(
|
|
3015
|
+
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
|
|
3016
|
+
"device_desc: %s\n",
|
|
3017
|
+
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
|
|
3018
|
+
std::string(info.device).c_str(), std::string(info.description).c_str());
|
|
3019
|
+
|
|
3020
|
+
// See GGML Backend Device Interface section
|
|
3021
|
+
static ggml_backend_device device = {
|
|
3022
|
+
/* .iface = */ ggml_backend_webgpu_device_i,
|
|
3023
|
+
/* .reg = */ reg,
|
|
3024
|
+
/* .context = */ &device_ctx,
|
|
3025
|
+
};
|
|
3026
|
+
|
|
3027
|
+
WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
|
|
3028
|
+
return &device;
|
|
3029
|
+
}
|
|
3030
|
+
|
|
3031
|
+
// Registry interface (vtable) installed as the .iface of the static registry
// returned by ggml_backend_webgpu_reg(). Initialization is positional: entries
// must stay in the field order of ggml_backend_reg_i. get_proc_address is NULL,
// so this registry exposes no extra backend-specific entry points.
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
    /* .get_name         = */ ggml_backend_webgpu_reg_get_name,
    /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
    /* .get_device       = */ ggml_backend_webgpu_reg_get_device,
    /* .get_proc_address = */ NULL,
};
|
|
3037
|
+
|
|
3038
|
+
/* End GGML Backend Registration Interface */
|
|
3039
|
+
|
|
3040
|
+
/*
 * Returns the WebGPU backend registry.
 *
 * Allocates a fresh webgpu_context, stores it (with the registry name and a device count
 * of 1) in a function-static registry context, and creates the WebGPU instance with the
 * TimedWaitAny instance feature (needed for the blocking WaitAny calls made when devices
 * are requested). Native (Dawn) builds additionally enable the "allow_unsafe_apis" toggle.
 *
 * Returns the address of a function-static ggml_backend_reg. On Emscripten, returns nullptr
 * if the instance cannot be created; on other targets that case aborts via GGML_ASSERT.
 * NOTE(review): every call allocates a new webgpu_context and replaces the one held by the
 * static registry context.
 */
ggml_backend_reg_t ggml_backend_webgpu_reg() {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");

    webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();

    static ggml_backend_webgpu_reg_context ctx;
    ctx.webgpu_ctx   = webgpu_ctx;
    ctx.name         = GGML_WEBGPU_NAME;
    ctx.device_count = 1;

    wgpu::InstanceDescriptor instance_descriptor{};
    std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
    instance_descriptor.requiredFeatures                     = instance_features.data();
    instance_descriptor.requiredFeatureCount                 = instance_features.size();

#ifndef __EMSCRIPTEN__
    const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
    wgpu::DawnTogglesDescriptor instanceTogglesDesc;
    instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
    // Derive the count from the array so it cannot drift from the toggle list.
    instanceTogglesDesc.enabledToggleCount = sizeof(instanceEnabledToggles) / sizeof(instanceEnabledToggles[0]);
    instance_descriptor.nextInChain        = &instanceTogglesDesc;
#endif

    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);

#ifdef __EMSCRIPTEN__
    if (webgpu_ctx->instance == nullptr) {
        GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
        return nullptr;
    }
#endif
    GGML_ASSERT(webgpu_ctx->instance != nullptr);

    static ggml_backend_reg reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_webgpu_reg_i,
        /* .context     = */ &ctx,
    };
    return &reg;
}
|
|
3080
|
+
|
|
3081
|
+
ggml_backend_t ggml_backend_webgpu_init(void) {
|
|
3082
|
+
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
|
|
3083
|
+
|
|
3084
|
+
return ggml_backend_webgpu_device_init(dev, nullptr);
|
|
3085
|
+
}
|
|
3086
|
+
|
|
3087
|
+
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
|