whispercpp 1.3.5 → 1.3.6
This diff shows the changes between publicly available package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE +1 -1
- data/README.md +99 -2
- data/ext/extconf.rb +1 -0
- data/ext/ruby_whisper.c +20 -4
- data/ext/ruby_whisper.h +30 -2
- data/ext/ruby_whisper_context.c +216 -124
- data/ext/ruby_whisper_context_params.c +163 -0
- data/ext/ruby_whisper_model.c +0 -1
- data/ext/ruby_whisper_params.c +0 -1
- data/ext/ruby_whisper_segment.c +0 -1
- data/ext/ruby_whisper_token.c +29 -9
- data/ext/ruby_whisper_transcribe.cpp +4 -1
- data/ext/ruby_whisper_vad_context.c +48 -1
- data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
- data/ext/ruby_whisper_vad_params.c +0 -1
- data/ext/ruby_whisper_vad_segment.c +0 -1
- data/ext/ruby_whisper_vad_segments.c +0 -1
- data/ext/sources/CMakeLists.txt +1 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/cmake/whisper-config.cmake.in +5 -40
- data/ext/sources/examples/bench/bench.cpp +23 -18
- data/ext/sources/examples/cli/cli.cpp +8 -0
- data/ext/sources/examples/common-ggml.cpp +2 -0
- data/ext/sources/examples/miniaudio.h +4507 -2131
- data/ext/sources/examples/server/server.cpp +18 -4
- data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
- data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
- data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
- data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
- data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
- data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
- data/ext/sources/examples/talk-llama/llama-context.h +27 -28
- data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
- data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
- data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
- data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
- data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
- data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
- data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
- data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
- data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
- data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
- data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
- data/ext/sources/examples/talk-llama/llama-model.h +72 -19
- data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
- data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
- data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
- data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
- data/ext/sources/examples/talk-llama/llama.cpp +76 -22
- data/ext/sources/examples/talk-llama/llama.h +63 -30
- data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
- data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
- data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
- data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
- data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
- data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
- data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
- data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
- data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
- data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
- data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
- data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
- data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
- data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
- data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
- data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
- data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
- data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
- data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
- data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
- data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
- data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
- data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/models.h +181 -46
- data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
- data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
- data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
- data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
- data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
- data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
- data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
- data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
- data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
- data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
- data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
- data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
- data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
- data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
- data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
- data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
- data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
- data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
- data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
- data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
- data/ext/sources/ggml/CMakeLists.txt +9 -3
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cann.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +5 -0
- data/ext/sources/ggml/include/ggml-openvino.h +37 -0
- data/ext/sources/ggml/include/ggml-opt.h +1 -1
- data/ext/sources/ggml/include/ggml-rpc.h +6 -1
- data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
- data/ext/sources/ggml/include/ggml.h +56 -9
- data/ext/sources/ggml/src/CMakeLists.txt +3 -0
- data/ext/sources/ggml/src/ggml-alloc.c +4 -9
- data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
- data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
- data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
- data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
- data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
- data/ext/sources/ggml/src/ggml-common.h +11 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
- data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
- data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
- data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
- data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
- data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
- data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
- data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
- data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
- data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
- data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
- data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
- data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
- data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
- data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
- data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
- data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
- data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
- data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
- data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
- data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
- data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
- data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
- data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
- data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
- data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
- data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
- data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
- data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
- data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
- data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
- data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
- data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
- data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
- data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
- data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
- data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
- data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
- data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
- data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
- data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
- data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
- data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
- data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
- data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
- data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
- data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
- data/ext/sources/ggml/src/ggml-impl.h +62 -0
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
- data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
- data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
- data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
- data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
- data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
- data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
- data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
- data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
- data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
- data/ext/sources/ggml/src/ggml-quants.c +96 -5
- data/ext/sources/ggml/src/ggml-quants.h +3 -0
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
- data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
- data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
- data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
- data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
- data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
- data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
- data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
- data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
- data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
- data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
- data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
- data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
- data/ext/sources/ggml/src/ggml.c +167 -33
- data/ext/sources/ggml/src/gguf.cpp +229 -44
- data/ext/sources/src/whisper.cpp +6 -28
- data/sig/whisper.rbs +43 -2
- data/test/test_context_params.rb +82 -0
- data/test/test_token.rb +11 -0
- data/test/test_vad_context.rb +58 -8
- data/test/test_whisper.rb +20 -0
- data/whispercpp.gemspec +1 -1
- metadata +240 -28
- data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
- data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
- data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
- data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
- data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
|
@@ -27,6 +27,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|
|
27
27
|
#include <iostream>
|
|
28
28
|
#include <tuple>
|
|
29
29
|
#include <vector>
|
|
30
|
+
#include <deque>
|
|
30
31
|
#include <sstream>
|
|
31
32
|
#include <utility>
|
|
32
33
|
#include <memory>
|
|
@@ -92,6 +93,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
|
|
92
93
|
#define VK_VENDOR_ID_APPLE 0x106b
|
|
93
94
|
#define VK_VENDOR_ID_INTEL 0x8086
|
|
94
95
|
#define VK_VENDOR_ID_NVIDIA 0x10de
|
|
96
|
+
#define VK_VENDOR_ID_QUALCOMM 0x5143
|
|
95
97
|
|
|
96
98
|
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
|
|
97
99
|
|
|
@@ -187,6 +189,11 @@ struct ggml_backend_vk_buffer_type_context {
|
|
|
187
189
|
|
|
188
190
|
struct vk_queue;
|
|
189
191
|
|
|
192
|
+
struct vk_command_buffer {
|
|
193
|
+
vk::CommandBuffer buf;
|
|
194
|
+
bool in_use = false;
|
|
195
|
+
};
|
|
196
|
+
|
|
190
197
|
// Stores command pool/buffers. There's an instance of this
|
|
191
198
|
// for each (context,queue) pair and for each (device,queue) pair.
|
|
192
199
|
struct vk_command_pool {
|
|
@@ -194,10 +201,16 @@ struct vk_command_pool {
|
|
|
194
201
|
void destroy(vk::Device& device);
|
|
195
202
|
|
|
196
203
|
vk::CommandPool pool;
|
|
197
|
-
|
|
198
|
-
|
|
204
|
+
// Using deque so the pointers to command buffers
|
|
205
|
+
// remain valid even if we add more
|
|
206
|
+
std::deque<vk_command_buffer> cmd_buffers;
|
|
199
207
|
|
|
200
208
|
vk_queue *q;
|
|
209
|
+
|
|
210
|
+
size_t buffers_in_use() const {
|
|
211
|
+
return std::count_if(cmd_buffers.begin(), cmd_buffers.end(),
|
|
212
|
+
[](const auto& cb) { return cb.in_use; });
|
|
213
|
+
}
|
|
201
214
|
};
|
|
202
215
|
|
|
203
216
|
// Prevent simultaneous submissions to the same queue.
|
|
@@ -254,6 +267,7 @@ enum vk_device_architecture {
|
|
|
254
267
|
AMD_RDNA3,
|
|
255
268
|
INTEL_XE2,
|
|
256
269
|
NVIDIA_PRE_TURING,
|
|
270
|
+
NVIDIA_TURING,
|
|
257
271
|
};
|
|
258
272
|
|
|
259
273
|
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
|
@@ -336,18 +350,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
|
|
336
350
|
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
|
337
351
|
|
|
338
352
|
bool cooperative_matrix = false;
|
|
353
|
+
bool sm_builtins = false;
|
|
339
354
|
|
|
340
355
|
// Detect "pre-turing" based on lack of coopmat support.
|
|
341
356
|
for (const auto& properties : ext_props) {
|
|
342
357
|
if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
|
|
343
358
|
cooperative_matrix = true;
|
|
344
|
-
|
|
359
|
+
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
|
360
|
+
sm_builtins = true;
|
|
345
361
|
}
|
|
346
362
|
}
|
|
347
363
|
|
|
348
364
|
if (!cooperative_matrix) {
|
|
349
365
|
return vk_device_architecture::NVIDIA_PRE_TURING;
|
|
350
366
|
}
|
|
367
|
+
|
|
368
|
+
if (sm_builtins) {
|
|
369
|
+
vk::PhysicalDeviceProperties2 props2;
|
|
370
|
+
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
|
371
|
+
|
|
372
|
+
props2.pNext = &sm_props;
|
|
373
|
+
|
|
374
|
+
device.getProperties2(&props2);
|
|
375
|
+
|
|
376
|
+
// Turing has 32, following architectures have 48
|
|
377
|
+
if (sm_props.shaderWarpsPerSM == 32) {
|
|
378
|
+
return vk_device_architecture::NVIDIA_TURING;
|
|
379
|
+
}
|
|
380
|
+
}
|
|
351
381
|
}
|
|
352
382
|
return vk_device_architecture::OTHER;
|
|
353
383
|
}
|
|
@@ -385,18 +415,20 @@ enum FaCodePath {
|
|
|
385
415
|
};
|
|
386
416
|
|
|
387
417
|
struct vk_fa_pipeline_state {
|
|
388
|
-
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
|
|
389
|
-
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
|
|
390
|
-
|
|
391
418
|
uint32_t HSK, HSV;
|
|
392
|
-
|
|
419
|
+
uint32_t Br, Bc;
|
|
420
|
+
uint32_t D_split, row_split;
|
|
421
|
+
bool shmem_staging;
|
|
393
422
|
FaCodePath path;
|
|
423
|
+
uint32_t workgroup_size, subgroup_size;
|
|
394
424
|
bool aligned;
|
|
395
425
|
bool f32acc;
|
|
426
|
+
uint32_t flags;
|
|
427
|
+
uint32_t limit_occupancy_shmem;
|
|
396
428
|
|
|
397
429
|
bool operator<(const vk_fa_pipeline_state &b) const {
|
|
398
|
-
return std::tie(HSK, HSV,
|
|
399
|
-
std::tie(b.HSK, b.HSV, b.
|
|
430
|
+
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
|
|
431
|
+
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
|
|
400
432
|
}
|
|
401
433
|
};
|
|
402
434
|
|
|
@@ -570,6 +602,7 @@ struct vk_device_struct {
|
|
|
570
602
|
vk_queue transfer_queue;
|
|
571
603
|
bool single_queue;
|
|
572
604
|
bool support_async;
|
|
605
|
+
bool async_use_transfer_queue;
|
|
573
606
|
uint32_t subgroup_size;
|
|
574
607
|
uint32_t subgroup_size_log2;
|
|
575
608
|
uint32_t shader_core_count;
|
|
@@ -669,6 +702,7 @@ struct vk_device_struct {
|
|
|
669
702
|
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
|
670
703
|
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
|
671
704
|
vk_pipeline pipeline_acc_f32;
|
|
705
|
+
vk_pipeline pipeline_set_f32;
|
|
672
706
|
|
|
673
707
|
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
|
|
674
708
|
vk_pipeline pipeline_add[2][2][2];
|
|
@@ -722,6 +756,7 @@ struct vk_device_struct {
|
|
|
722
756
|
|
|
723
757
|
// [src/dst 0=fp32,1=fp16]
|
|
724
758
|
vk_pipeline pipeline_exp[2];
|
|
759
|
+
vk_pipeline pipeline_elu[2];
|
|
725
760
|
vk_pipeline pipeline_gelu[2];
|
|
726
761
|
vk_pipeline pipeline_gelu_erf[2];
|
|
727
762
|
vk_pipeline pipeline_gelu_quick[2];
|
|
@@ -740,6 +775,7 @@ struct vk_device_struct {
|
|
|
740
775
|
vk_pipeline pipeline_ceil[2];
|
|
741
776
|
vk_pipeline pipeline_floor[2];
|
|
742
777
|
vk_pipeline pipeline_trunc[2];
|
|
778
|
+
vk_pipeline pipeline_sgn[2];
|
|
743
779
|
|
|
744
780
|
vk_pipeline pipeline_add1_f16_f16;
|
|
745
781
|
vk_pipeline pipeline_add1_f16_f32;
|
|
@@ -789,6 +825,8 @@ struct vk_device_struct {
|
|
|
789
825
|
vk_pipeline pipeline_pool2d_f32;
|
|
790
826
|
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
791
827
|
vk_pipeline pipeline_rwkv_wkv7_f32;
|
|
828
|
+
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
|
|
829
|
+
vk_pipeline pipeline_gated_delta_net[3][2];
|
|
792
830
|
vk_pipeline pipeline_ssm_scan_f32_d128;
|
|
793
831
|
vk_pipeline pipeline_ssm_scan_f32_d256;
|
|
794
832
|
vk_pipeline pipeline_ssm_conv_f32;
|
|
@@ -803,6 +841,8 @@ struct vk_device_struct {
|
|
|
803
841
|
|
|
804
842
|
std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
|
|
805
843
|
|
|
844
|
+
std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
|
|
845
|
+
|
|
806
846
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
807
847
|
vk_pipeline pipeline_count_experts;
|
|
808
848
|
|
|
@@ -852,10 +892,12 @@ struct vk_device_struct {
|
|
|
852
892
|
};
|
|
853
893
|
|
|
854
894
|
void vk_command_pool::init(vk_device& device, vk_queue *q_) {
|
|
855
|
-
|
|
895
|
+
cmd_buffers.clear();
|
|
856
896
|
q = q_;
|
|
857
897
|
|
|
858
|
-
vk::CommandPoolCreateInfo command_pool_create_info(
|
|
898
|
+
vk::CommandPoolCreateInfo command_pool_create_info(
|
|
899
|
+
vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT),
|
|
900
|
+
q->queue_family_index);
|
|
859
901
|
pool = device->device.createCommandPool(command_pool_create_info);
|
|
860
902
|
}
|
|
861
903
|
|
|
@@ -903,6 +945,7 @@ struct vk_subbuffer {
|
|
|
903
945
|
struct vk_event {
|
|
904
946
|
vk::Event event;
|
|
905
947
|
vk::Fence fence;
|
|
948
|
+
vk_command_buffer* cmd_buffer = nullptr;
|
|
906
949
|
};
|
|
907
950
|
|
|
908
951
|
struct vk_semaphore {
|
|
@@ -911,7 +954,7 @@ struct vk_semaphore {
|
|
|
911
954
|
};
|
|
912
955
|
|
|
913
956
|
struct vk_submission {
|
|
914
|
-
|
|
957
|
+
vk_command_buffer* buffer = nullptr;
|
|
915
958
|
std::vector<vk_semaphore> wait_semaphores;
|
|
916
959
|
std::vector<vk_semaphore> signal_semaphores;
|
|
917
960
|
};
|
|
@@ -922,6 +965,7 @@ struct vk_mat_mat_push_constants {
|
|
|
922
965
|
uint32_t M; uint32_t N; uint32_t K;
|
|
923
966
|
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
|
|
924
967
|
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
|
|
968
|
+
uint32_t base_work_group_z; uint32_t num_batches;
|
|
925
969
|
uint32_t k_split;
|
|
926
970
|
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
|
|
927
971
|
uint32_t padded_N;
|
|
@@ -941,6 +985,7 @@ struct vk_mat_vec_push_constants {
|
|
|
941
985
|
uint32_t batch_stride_b;
|
|
942
986
|
uint32_t batch_stride_d;
|
|
943
987
|
uint32_t fusion_flags;
|
|
988
|
+
uint32_t base_work_group_y;
|
|
944
989
|
uint32_t ne02;
|
|
945
990
|
uint32_t ne12;
|
|
946
991
|
uint32_t broadcast2;
|
|
@@ -991,6 +1036,8 @@ struct vk_mat_vec_id_push_constants {
|
|
|
991
1036
|
uint32_t fusion_flags;
|
|
992
1037
|
uint32_t nei0;
|
|
993
1038
|
uint32_t ne11;
|
|
1039
|
+
uint32_t expert_i1;
|
|
1040
|
+
uint32_t nbi1;
|
|
994
1041
|
};
|
|
995
1042
|
|
|
996
1043
|
struct vk_flash_attn_push_constants {
|
|
@@ -1244,25 +1291,30 @@ struct vk_op_diag_mask_push_constants {
|
|
|
1244
1291
|
|
|
1245
1292
|
struct vk_op_rope_push_constants {
|
|
1246
1293
|
uint32_t rope_mode;
|
|
1247
|
-
uint32_t ncols;
|
|
1248
1294
|
uint32_t nrows;
|
|
1249
1295
|
uint32_t n_dims;
|
|
1250
1296
|
float freq_scale;
|
|
1251
|
-
uint32_t p_delta_rows;
|
|
1252
1297
|
float freq_base;
|
|
1253
1298
|
float ext_factor;
|
|
1254
1299
|
float attn_factor;
|
|
1255
1300
|
float corr_dims[2];
|
|
1256
1301
|
float theta_scale;
|
|
1257
1302
|
uint32_t has_ff;
|
|
1258
|
-
uint32_t ne02;
|
|
1259
|
-
uint32_t s1;
|
|
1260
|
-
uint32_t s2;
|
|
1261
1303
|
int32_t sections[4];
|
|
1262
1304
|
uint32_t is_imrope;
|
|
1263
1305
|
uint32_t is_back;
|
|
1264
1306
|
uint32_t set_rows_stride;
|
|
1307
|
+
uint32_t ne00;
|
|
1308
|
+
uint32_t ne01;
|
|
1309
|
+
uint32_t ne02;
|
|
1310
|
+
uint32_t nb01;
|
|
1311
|
+
uint32_t nb02;
|
|
1312
|
+
uint32_t nb03;
|
|
1313
|
+
uint32_t nb11;
|
|
1314
|
+
uint32_t nb12;
|
|
1315
|
+
uint32_t nb13;
|
|
1265
1316
|
};
|
|
1317
|
+
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
|
|
1266
1318
|
|
|
1267
1319
|
// For fused rms_norm+mul+rope(+view+set_rows)
|
|
1268
1320
|
struct vk_op_rms_norm_mul_rope_push_constants {
|
|
@@ -1404,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
|
|
|
1404
1456
|
uint32_t C;
|
|
1405
1457
|
uint32_t H;
|
|
1406
1458
|
};
|
|
1459
|
+
struct vk_op_gated_delta_net_push_constants {
|
|
1460
|
+
uint32_t H;
|
|
1461
|
+
uint32_t n_tokens;
|
|
1462
|
+
uint32_t n_seqs;
|
|
1463
|
+
uint32_t s_off;
|
|
1464
|
+
uint32_t sq1, sq2, sq3;
|
|
1465
|
+
uint32_t sv1, sv2, sv3;
|
|
1466
|
+
uint32_t sb1, sb2, sb3;
|
|
1467
|
+
uint32_t neq1, rq3;
|
|
1468
|
+
float scale;
|
|
1469
|
+
};
|
|
1470
|
+
|
|
1407
1471
|
struct vk_op_ssm_scan_push_constants {
|
|
1408
1472
|
uint32_t nb02, nb03, nb12, nb13;
|
|
1409
1473
|
uint32_t nb21, nb22, nb31;
|
|
@@ -1516,6 +1580,27 @@ struct vk_quantize_q8_1_push_constants {
|
|
|
1516
1580
|
uint32_t num_blocks;
|
|
1517
1581
|
};
|
|
1518
1582
|
|
|
1583
|
+
struct vk_op_flash_attn_split_k_reduce_push_constants {
|
|
1584
|
+
uint32_t D;
|
|
1585
|
+
uint32_t ne1;
|
|
1586
|
+
uint32_t ne2;
|
|
1587
|
+
uint32_t ne3;
|
|
1588
|
+
uint32_t k_num;
|
|
1589
|
+
uint32_t sinks;
|
|
1590
|
+
};
|
|
1591
|
+
|
|
1592
|
+
struct vk_op_flash_attn_mask_opt_push_constants {
|
|
1593
|
+
uint32_t nem0;
|
|
1594
|
+
uint32_t nem1;
|
|
1595
|
+
uint32_t nem2;
|
|
1596
|
+
uint32_t nbm1;
|
|
1597
|
+
uint32_t nbm2;
|
|
1598
|
+
uint32_t nbm3;
|
|
1599
|
+
uint32_t nbd1;
|
|
1600
|
+
uint32_t nbd2;
|
|
1601
|
+
uint32_t nbd3;
|
|
1602
|
+
};
|
|
1603
|
+
|
|
1519
1604
|
// Allow pre-recording command buffers
|
|
1520
1605
|
struct vk_staging_memcpy {
|
|
1521
1606
|
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
@@ -1604,6 +1689,7 @@ static bool vk_perf_logger_concurrent = false;
|
|
|
1604
1689
|
static bool vk_enable_sync_logger = false;
|
|
1605
1690
|
// number of calls between perf logger prints
|
|
1606
1691
|
static uint32_t vk_perf_logger_frequency = 1;
|
|
1692
|
+
static std::string vk_pipeline_stats_filter;
|
|
1607
1693
|
|
|
1608
1694
|
class vk_perf_logger {
|
|
1609
1695
|
public:
|
|
@@ -1724,6 +1810,7 @@ class vk_perf_logger {
|
|
|
1724
1810
|
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
|
|
1725
1811
|
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
|
|
1726
1812
|
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
|
|
1813
|
+
*n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
|
|
1727
1814
|
return name.str();
|
|
1728
1815
|
}
|
|
1729
1816
|
if (node->op == GGML_OP_TOP_K) {
|
|
@@ -1802,7 +1889,10 @@ struct ggml_backend_vk_context {
|
|
|
1802
1889
|
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
|
|
1803
1890
|
|
|
1804
1891
|
vk_context_ref compute_ctx;
|
|
1892
|
+
|
|
1805
1893
|
vk_context_ref transfer_ctx;
|
|
1894
|
+
vk_semaphore transfer_semaphore;
|
|
1895
|
+
uint64_t transfer_semaphore_last_submitted {};
|
|
1806
1896
|
|
|
1807
1897
|
std::vector<vk_context_ref> tensor_ctxs;
|
|
1808
1898
|
|
|
@@ -2121,7 +2211,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
2121
2211
|
executableInfo.pipeline = pipeline->pipeline;
|
|
2122
2212
|
|
|
2123
2213
|
auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
|
|
2214
|
+
|
|
2215
|
+
bool print_stats = !vk_pipeline_stats_filter.empty() &&
|
|
2216
|
+
pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
|
|
2217
|
+
if (print_stats) {
|
|
2218
|
+
std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
|
|
2219
|
+
}
|
|
2220
|
+
|
|
2124
2221
|
for (auto & s : statistics) {
|
|
2222
|
+
if (print_stats) {
|
|
2223
|
+
std::cerr << "ggml_vulkan: " << s.name.data() << ": ";
|
|
2224
|
+
switch (s.format) {
|
|
2225
|
+
case vk::PipelineExecutableStatisticFormatKHR::eBool32:
|
|
2226
|
+
std::cerr << (s.value.b32 ? "true" : "false");
|
|
2227
|
+
break;
|
|
2228
|
+
case vk::PipelineExecutableStatisticFormatKHR::eInt64:
|
|
2229
|
+
std::cerr << s.value.i64;
|
|
2230
|
+
break;
|
|
2231
|
+
case vk::PipelineExecutableStatisticFormatKHR::eUint64:
|
|
2232
|
+
std::cerr << s.value.u64;
|
|
2233
|
+
break;
|
|
2234
|
+
case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
|
|
2235
|
+
std::cerr << s.value.f64;
|
|
2236
|
+
break;
|
|
2237
|
+
}
|
|
2238
|
+
std::cerr << std::endl;
|
|
2239
|
+
}
|
|
2125
2240
|
// "Register Count" is reported by NVIDIA drivers.
|
|
2126
2241
|
if (strcmp(s.name, "Register Count") == 0) {
|
|
2127
2242
|
VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
|
|
@@ -2197,25 +2312,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx
|
|
|
2197
2312
|
}
|
|
2198
2313
|
}
|
|
2199
2314
|
|
|
2200
|
-
static
|
|
2315
|
+
static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
|
|
2201
2316
|
VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
|
|
2202
|
-
|
|
2203
|
-
if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
|
|
2204
|
-
// Reuse command buffer
|
|
2205
|
-
return p.cmd_buffers[p.cmd_buffer_idx++];
|
|
2206
|
-
}
|
|
2207
|
-
|
|
2208
2317
|
vk::CommandBufferAllocateInfo command_buffer_alloc_info(
|
|
2209
2318
|
p.pool,
|
|
2210
2319
|
vk::CommandBufferLevel::ePrimary,
|
|
2211
2320
|
1);
|
|
2212
2321
|
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
|
|
2213
|
-
|
|
2214
|
-
|
|
2215
|
-
p.cmd_buffers.push_back(buf);
|
|
2216
|
-
p.cmd_buffer_idx++;
|
|
2217
|
-
|
|
2218
|
-
return buf;
|
|
2322
|
+
p.cmd_buffers.push_back({ cmd_buffers.front(), true });
|
|
2323
|
+
return &p.cmd_buffers[p.cmd_buffers.size()-1];
|
|
2219
2324
|
}
|
|
2220
2325
|
|
|
2221
2326
|
static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
|
|
@@ -2282,7 +2387,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
|
|
|
2282
2387
|
tl_wait_semaphores[idx].data(),
|
|
2283
2388
|
stage_flags[idx].data(),
|
|
2284
2389
|
1,
|
|
2285
|
-
&submission.buffer,
|
|
2390
|
+
&submission.buffer->buf,
|
|
2286
2391
|
(uint32_t) submission.signal_semaphores.size(),
|
|
2287
2392
|
tl_signal_semaphores[idx].data(),
|
|
2288
2393
|
};
|
|
@@ -2406,7 +2511,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p)
|
|
|
2406
2511
|
|
|
2407
2512
|
// Requires command buffers to be done
|
|
2408
2513
|
device->device.resetCommandPool(p.pool);
|
|
2409
|
-
|
|
2514
|
+
// Don't clear the command buffers and mark them as not in use.
|
|
2515
|
+
// This allows us to reuse them
|
|
2516
|
+
for (auto& cmd_buffer : p.cmd_buffers) {
|
|
2517
|
+
cmd_buffer.in_use = false;
|
|
2518
|
+
}
|
|
2410
2519
|
}
|
|
2411
2520
|
|
|
2412
2521
|
static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
|
|
@@ -2415,10 +2524,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
|
|
|
2415
2524
|
// Arbitrary frequency to cleanup/reuse command buffers
|
|
2416
2525
|
static constexpr uint32_t cleanup_frequency = 10;
|
|
2417
2526
|
|
|
2418
|
-
if (device->compute_queue.cmd_pool.
|
|
2527
|
+
if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
|
|
2419
2528
|
ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
|
|
2420
2529
|
}
|
|
2421
|
-
if (device->transfer_queue.cmd_pool.
|
|
2530
|
+
if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
|
|
2422
2531
|
ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
|
|
2423
2532
|
}
|
|
2424
2533
|
}
|
|
@@ -2666,7 +2775,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
|
|
|
2666
2775
|
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
|
|
2667
2776
|
}
|
|
2668
2777
|
|
|
2669
|
-
subctx->s->buffer.pipelineBarrier(
|
|
2778
|
+
subctx->s->buffer->buf.pipelineBarrier(
|
|
2670
2779
|
subctx->p->q->stage_flags,
|
|
2671
2780
|
subctx->p->q->stage_flags,
|
|
2672
2781
|
{},
|
|
@@ -2682,7 +2791,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
|
|
|
2682
2791
|
static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
|
|
2683
2792
|
VK_LOG_DEBUG("ggml_vk_set_event()");
|
|
2684
2793
|
|
|
2685
|
-
ctx->s->buffer.setEvent(
|
|
2794
|
+
ctx->s->buffer->buf.setEvent(
|
|
2686
2795
|
event,
|
|
2687
2796
|
ctx->p->q->stage_flags
|
|
2688
2797
|
);
|
|
@@ -2694,7 +2803,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
2694
2803
|
return;
|
|
2695
2804
|
}
|
|
2696
2805
|
|
|
2697
|
-
ctx->s->buffer.waitEvents(
|
|
2806
|
+
ctx->s->buffer->buf.waitEvents(
|
|
2698
2807
|
events,
|
|
2699
2808
|
ctx->p->q->stage_flags,
|
|
2700
2809
|
ctx->p->q->stage_flags,
|
|
@@ -2704,78 +2813,218 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
2704
2813
|
);
|
|
2705
2814
|
}
|
|
2706
2815
|
|
|
2707
|
-
|
|
2708
|
-
|
|
2709
|
-
|
|
2816
|
+
struct vk_fa_tuning_params {
|
|
2817
|
+
FaCodePath path;
|
|
2818
|
+
uint32_t workgroup_size;
|
|
2819
|
+
uint32_t subgroup_size;
|
|
2820
|
+
uint32_t block_rows;
|
|
2821
|
+
uint32_t block_cols;
|
|
2822
|
+
uint32_t d_split;
|
|
2823
|
+
uint32_t row_split;
|
|
2824
|
+
bool shmem_staging;
|
|
2825
|
+
bool disable_subgroups;
|
|
2826
|
+
uint32_t limit_occupancy_shmem;
|
|
2827
|
+
|
|
2828
|
+
void print() const {
|
|
2829
|
+
std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
|
|
2830
|
+
" block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
|
|
2831
|
+
" row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
|
|
2832
|
+
" limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
|
|
2833
|
+
}
|
|
2834
|
+
};
|
|
2835
|
+
|
|
2836
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
|
2837
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
|
2710
2838
|
|
|
2711
|
-
static
|
|
2712
|
-
|
|
2713
|
-
|
|
2714
|
-
|
|
2715
|
-
|
|
2839
|
+
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
|
2840
|
+
GGML_UNUSED(kv_type);
|
|
2841
|
+
|
|
2842
|
+
vk_fa_tuning_params result{};
|
|
2843
|
+
result.path = FA_SCALAR;
|
|
2844
|
+
|
|
2845
|
+
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
|
2846
|
+
// Disable subgroup use due to performance issues when enforcing subgroup sizes
|
|
2847
|
+
result.subgroup_size = 32;
|
|
2848
|
+
result.disable_subgroups = true;
|
|
2849
|
+
} else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
|
|
2850
|
+
result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
|
|
2716
2851
|
} else {
|
|
2717
|
-
|
|
2852
|
+
result.subgroup_size = device->subgroup_size;
|
|
2718
2853
|
}
|
|
2719
|
-
}
|
|
2720
2854
|
|
|
2721
|
-
//
|
|
2722
|
-
|
|
2723
|
-
|
|
2724
|
-
|
|
2725
|
-
|
|
2726
|
-
|
|
2855
|
+
// Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
|
|
2856
|
+
uint32_t row_split_max_hsk = 64;
|
|
2857
|
+
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
|
|
2858
|
+
row_split_max_hsk = n_rows <= 8 ? 64 : 128;
|
|
2859
|
+
}
|
|
2860
|
+
result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
|
|
2727
2861
|
|
|
2728
|
-
|
|
2729
|
-
|
|
2730
|
-
return flash_attention_num_small_rows;
|
|
2862
|
+
if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
|
|
2863
|
+
result.workgroup_size = result.subgroup_size * 2;
|
|
2731
2864
|
} else {
|
|
2732
|
-
|
|
2865
|
+
result.workgroup_size = result.subgroup_size * 4;
|
|
2733
2866
|
}
|
|
2734
|
-
}
|
|
2735
2867
|
|
|
2736
|
-
|
|
2737
|
-
GGML_UNUSED(clamp);
|
|
2868
|
+
const uint32_t D = hsk | hsv;
|
|
2738
2869
|
|
|
2739
|
-
|
|
2740
|
-
|
|
2741
|
-
|
|
2870
|
+
const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
|
|
2871
|
+
|
|
2872
|
+
if (n_rows == 1) {
|
|
2873
|
+
result.block_rows = 1;
|
|
2874
|
+
result.block_cols = 64;
|
|
2875
|
+
} else {
|
|
2876
|
+
// row_split 1 means higher register use per row, so block size has to be adjusted
|
|
2877
|
+
if (result.row_split == 1) {
|
|
2878
|
+
result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
|
|
2742
2879
|
} else {
|
|
2743
|
-
|
|
2744
|
-
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
|
2745
|
-
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
|
2746
|
-
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
|
2747
|
-
} else {
|
|
2748
|
-
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
|
|
2749
|
-
}
|
|
2880
|
+
result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
|
|
2750
2881
|
}
|
|
2882
|
+
|
|
2883
|
+
result.block_cols = (D & 8) ? 64 : 32;
|
|
2751
2884
|
}
|
|
2752
2885
|
|
|
2753
|
-
|
|
2754
|
-
|
|
2755
|
-
|
|
2756
|
-
|
|
2757
|
-
|
|
2886
|
+
const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
|
|
2887
|
+
|
|
2888
|
+
result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
|
|
2889
|
+
|
|
2890
|
+
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
|
2891
|
+
|
|
2892
|
+
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
|
2893
|
+
result.block_rows /= 2;
|
|
2894
|
+
}
|
|
2895
|
+
|
|
2896
|
+
// On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
|
|
2897
|
+
// at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
|
|
2898
|
+
// This targets an occupancy of 4 subgroups per SIMD.
|
|
2899
|
+
if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
|
|
2900
|
+
if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
|
|
2901
|
+
// 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
|
|
2902
|
+
// Values are guessed, tested on RDNA2
|
|
2903
|
+
result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
|
|
2904
|
+
} else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
|
|
2905
|
+
// Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
|
|
2906
|
+
// Here low-batch FA with large head size is affected.
|
|
2907
|
+
// n_rows < 4 switch because workgroup size switches from 128 to 256 there.
|
|
2908
|
+
result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
|
|
2758
2909
|
}
|
|
2759
2910
|
}
|
|
2760
2911
|
|
|
2761
|
-
|
|
2912
|
+
return result;
|
|
2913
|
+
}
|
|
2914
|
+
|
|
2915
|
+
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
|
2916
|
+
GGML_UNUSED(n_rows);
|
|
2917
|
+
GGML_UNUSED(n_kv);
|
|
2918
|
+
GGML_UNUSED(kv_type);
|
|
2919
|
+
GGML_UNUSED(f32acc);
|
|
2920
|
+
|
|
2921
|
+
vk_fa_tuning_params result{};
|
|
2922
|
+
result.path = FA_COOPMAT1;
|
|
2923
|
+
|
|
2924
|
+
const uint32_t D = hsk | hsv;
|
|
2925
|
+
|
|
2926
|
+
const uint32_t coopmat_block_rows = 16;
|
|
2927
|
+
const uint32_t coopmat_block_cols = 16;
|
|
2928
|
+
|
|
2929
|
+
const uint32_t num_subgroups = 4;
|
|
2930
|
+
|
|
2931
|
+
result.block_rows = coopmat_block_rows;
|
|
2932
|
+
result.block_cols = coopmat_block_cols * num_subgroups;
|
|
2933
|
+
result.row_split = num_subgroups;
|
|
2934
|
+
result.subgroup_size = device->subgroup_size;
|
|
2935
|
+
result.workgroup_size = num_subgroups * result.subgroup_size;
|
|
2936
|
+
|
|
2937
|
+
const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
|
|
2938
|
+
result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
|
|
2939
|
+
|
|
2940
|
+
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
|
2941
|
+
|
|
2942
|
+
return result;
|
|
2943
|
+
}
|
|
2944
|
+
|
|
2945
|
+
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
|
2946
|
+
GGML_UNUSED(n_kv);
|
|
2947
|
+
GGML_UNUSED(f32acc);
|
|
2948
|
+
|
|
2949
|
+
vk_fa_tuning_params result{};
|
|
2950
|
+
result.path = FA_COOPMAT2;
|
|
2951
|
+
|
|
2952
|
+
const uint32_t D = hsk | hsv;
|
|
2953
|
+
|
|
2954
|
+
const bool small_rows = n_rows < 32;
|
|
2955
|
+
|
|
2762
2956
|
if (small_rows) {
|
|
2763
|
-
|
|
2957
|
+
result.block_rows = 32;
|
|
2958
|
+
result.block_cols = 32;
|
|
2959
|
+
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
|
|
2960
|
+
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
|
|
2961
|
+
result.block_cols = 32;
|
|
2962
|
+
} else {
|
|
2963
|
+
result.block_rows = 64;
|
|
2964
|
+
result.block_cols = 64;
|
|
2764
2965
|
}
|
|
2765
2966
|
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
|
|
2769
|
-
|
|
2770
|
-
|
|
2771
|
-
|
|
2967
|
+
result.subgroup_size = device->subgroup_size;
|
|
2968
|
+
result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
|
2969
|
+
|
|
2970
|
+
return result;
|
|
2971
|
+
}
|
|
2972
|
+
|
|
2973
|
+
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
|
2974
|
+
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
|
2975
|
+
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
|
2976
|
+
|
|
2977
|
+
if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
|
|
2978
|
+
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
|
|
2979
|
+
path = FA_SCALAR;
|
|
2980
|
+
}
|
|
2981
|
+
|
|
2982
|
+
if (path == FA_COOPMAT1) {
|
|
2983
|
+
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
|
2984
|
+
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
|
2985
|
+
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
|
2986
|
+
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
|
2987
|
+
|
|
2988
|
+
if (!shape_ok || !shmem_ok) {
|
|
2989
|
+
path = FA_SCALAR;
|
|
2772
2990
|
}
|
|
2773
2991
|
}
|
|
2774
|
-
|
|
2992
|
+
|
|
2993
|
+
// scalar is faster than coopmat when N==1
|
|
2994
|
+
if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
|
|
2995
|
+
path = FA_SCALAR;
|
|
2996
|
+
}
|
|
2997
|
+
|
|
2998
|
+
switch (path) {
|
|
2999
|
+
case FA_SCALAR:
|
|
3000
|
+
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
|
3001
|
+
case FA_COOPMAT1:
|
|
3002
|
+
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
|
3003
|
+
case FA_COOPMAT2:
|
|
3004
|
+
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
|
3005
|
+
default:
|
|
3006
|
+
throw std::runtime_error("unsupported FaCodePath");
|
|
3007
|
+
}
|
|
2775
3008
|
}
|
|
2776
3009
|
|
|
2777
|
-
static
|
|
2778
|
-
|
|
3010
|
+
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
|
3011
|
+
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
|
3012
|
+
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
|
|
3013
|
+
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
|
|
3014
|
+
|
|
3015
|
+
uint32_t flags = (use_mask_opt ? 1 : 0) |
|
|
3016
|
+
(use_mask ? 2 : 0) |
|
|
3017
|
+
(use_logit_softcap ? 4 : 0) |
|
|
3018
|
+
(old_amd_windows ? 8 : 0);
|
|
3019
|
+
|
|
3020
|
+
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
|
3021
|
+
|
|
3022
|
+
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
|
|
3023
|
+
}
|
|
3024
|
+
|
|
3025
|
+
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
|
3026
|
+
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
|
|
3027
|
+
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
|
|
2779
3028
|
}
|
|
2780
3029
|
|
|
2781
3030
|
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
|
@@ -3142,60 +3391,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3142
3391
|
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
|
3143
3392
|
};
|
|
3144
3393
|
|
|
3145
|
-
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
|
|
3146
|
-
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
|
|
3147
|
-
};
|
|
3148
|
-
|
|
3149
|
-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
|
|
3150
|
-
// For large number of rows, 128 invocations seems to work best.
|
|
3151
|
-
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
3152
|
-
// can't use 256 for D==80.
|
|
3153
|
-
// For scalar, use 128 (arbitrary)
|
|
3154
|
-
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
|
3155
|
-
const uint32_t D = (hsk|hsv);
|
|
3156
|
-
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
3157
|
-
? scalar_flash_attention_workgroup_size
|
|
3158
|
-
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
3159
|
-
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
|
|
3160
|
-
|
|
3161
|
-
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
3162
|
-
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
3163
|
-
const uint32_t D_lsb = D ^ (D & (D-1));
|
|
3164
|
-
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
|
3165
|
-
|
|
3166
|
-
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
|
3167
|
-
};
|
|
3168
|
-
|
|
3169
3394
|
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
3170
3395
|
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
|
|
3171
|
-
uint32_t HSK = fa.first.HSK; \
|
|
3172
|
-
uint32_t HSV = fa.first.HSV; \
|
|
3173
|
-
bool small_rows = fa.first.small_rows; \
|
|
3174
|
-
bool small_cache = fa.first.small_cache; \
|
|
3175
3396
|
FaCodePath path = fa.first.path; \
|
|
3397
|
+
uint32_t Br = fa.first.Br; \
|
|
3398
|
+
uint32_t Bc = fa.first.Bc; \
|
|
3176
3399
|
bool aligned = fa.first.aligned; \
|
|
3177
3400
|
bool f32acc = fa.first.f32acc; \
|
|
3401
|
+
uint32_t fa_sgs = fa.first.subgroup_size; \
|
|
3402
|
+
bool fa_ds = fa.first.subgroup_size == 0; \
|
|
3178
3403
|
if (path == FAPATH) { \
|
|
3179
3404
|
if (aligned) { \
|
|
3180
3405
|
if (f32acc) { \
|
|
3181
|
-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
|
|
3406
|
+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
|
3182
3407
|
} else { \
|
|
3183
|
-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
|
|
3408
|
+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
|
3184
3409
|
} \
|
|
3185
3410
|
} else { \
|
|
3186
3411
|
if (f32acc) { \
|
|
3187
|
-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main",
|
|
3412
|
+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
|
3188
3413
|
} else { \
|
|
3189
|
-
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main",
|
|
3414
|
+
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
|
3190
3415
|
} \
|
|
3191
3416
|
} \
|
|
3192
3417
|
} \
|
|
3193
3418
|
}
|
|
3194
3419
|
|
|
3195
|
-
|
|
3196
|
-
|
|
3197
|
-
|
|
3198
|
-
|
|
3420
|
+
if (device->fp16) {
|
|
3421
|
+
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
|
3422
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
3423
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
3424
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
|
3425
|
+
} else {
|
|
3426
|
+
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
|
3427
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
|
3428
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
|
3429
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
|
3430
|
+
}
|
|
3199
3431
|
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
3200
3432
|
if (device->coopmat1_fa_support) {
|
|
3201
3433
|
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
|
@@ -3713,10 +3945,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3713
3945
|
&& !device->coopmat_bf16_support
|
|
3714
3946
|
#endif
|
|
3715
3947
|
) {
|
|
3948
|
+
const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
|
|
3949
|
+
|
|
3716
3950
|
// use scalar tile sizes
|
|
3717
3951
|
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
|
3718
3952
|
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
|
3719
|
-
s_warptile = {
|
|
3953
|
+
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
3720
3954
|
|
|
3721
3955
|
l_wg_denoms = {128, 128, 1 };
|
|
3722
3956
|
m_wg_denoms = { 64, 64, 1 };
|
|
@@ -3980,7 +4214,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
3980
4214
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
3981
4215
|
|
|
3982
4216
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
3983
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3,
|
|
4217
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
|
4218
|
+
|
|
4219
|
+
for (auto &it : device->pipeline_fa_mask_opt) {
|
|
4220
|
+
auto BrBc = it.first;
|
|
4221
|
+
ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
|
|
4222
|
+
}
|
|
3984
4223
|
|
|
3985
4224
|
if (device->subgroup_clustered && device->subgroup_require_full_support) {
|
|
3986
4225
|
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
|
|
@@ -4012,7 +4251,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4012
4251
|
}
|
|
4013
4252
|
|
|
4014
4253
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
4015
|
-
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(
|
|
4254
|
+
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
|
4016
4255
|
|
|
4017
4256
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
4018
4257
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -4113,7 +4352,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4113
4352
|
|
|
4114
4353
|
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
|
|
4115
4354
|
|
|
4116
|
-
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
4355
|
+
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
|
|
4356
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
|
|
4117
4357
|
|
|
4118
4358
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
4119
4359
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -4158,6 +4398,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4158
4398
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
4159
4399
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
4160
4400
|
|
|
4401
|
+
CREATE_UNARY(elu)
|
|
4161
4402
|
CREATE_UNARY(gelu)
|
|
4162
4403
|
CREATE_UNARY(gelu_erf)
|
|
4163
4404
|
CREATE_UNARY(gelu_quick)
|
|
@@ -4176,6 +4417,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4176
4417
|
CREATE_UNARY(ceil)
|
|
4177
4418
|
CREATE_UNARY(floor)
|
|
4178
4419
|
CREATE_UNARY(trunc)
|
|
4420
|
+
CREATE_UNARY(sgn)
|
|
4179
4421
|
#undef CREATE_UNARY
|
|
4180
4422
|
|
|
4181
4423
|
#define CREATE_UNARY_RTE(name) \
|
|
@@ -4340,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4340
4582
|
|
|
4341
4583
|
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
|
4342
4584
|
|
|
4585
|
+
{
|
|
4586
|
+
const uint32_t gdn_sizes[] = {32, 64, 128};
|
|
4587
|
+
const char * gdn_names[][2] = {
|
|
4588
|
+
{"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
|
|
4589
|
+
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
|
|
4590
|
+
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
|
|
4591
|
+
};
|
|
4592
|
+
for (uint32_t si = 0; si < 3; si++) {
|
|
4593
|
+
for (uint32_t kda = 0; kda < 2; kda++) {
|
|
4594
|
+
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
|
|
4595
|
+
gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
|
|
4596
|
+
"main", 7, sizeof(vk_op_gated_delta_net_push_constants),
|
|
4597
|
+
{1, 1, 1}, {gdn_sizes[si], kda}, 1);
|
|
4598
|
+
}
|
|
4599
|
+
}
|
|
4600
|
+
}
|
|
4601
|
+
|
|
4343
4602
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
|
4344
4603
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
|
|
4345
4604
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
|
|
@@ -4348,7 +4607,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4348
4607
|
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
|
4349
4608
|
}
|
|
4350
4609
|
|
|
4351
|
-
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32,
|
|
4610
|
+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
|
|
4352
4611
|
|
|
4353
4612
|
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
4354
4613
|
|
|
@@ -4460,6 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
4460
4719
|
}
|
|
4461
4720
|
|
|
4462
4721
|
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
|
|
4722
|
+
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
|
|
4463
4723
|
|
|
4464
4724
|
static vk_device ggml_vk_get_device(size_t idx) {
|
|
4465
4725
|
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
|
|
@@ -4676,6 +4936,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4676
4936
|
device->shader_core_count = sm_props.shaderSMCount;
|
|
4677
4937
|
} else if (amd_shader_core_properties2) {
|
|
4678
4938
|
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
|
|
4939
|
+
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
|
4940
|
+
device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
|
|
4679
4941
|
} else {
|
|
4680
4942
|
device->shader_core_count = 0;
|
|
4681
4943
|
}
|
|
@@ -4719,8 +4981,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4719
4981
|
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
|
4720
4982
|
|
|
4721
4983
|
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
4722
|
-
|
|
4723
|
-
const
|
|
4984
|
+
// On AMD, the graphics queue seems to be faster, so don't avoid it
|
|
4985
|
+
const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
|
|
4986
|
+
const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
|
|
4987
|
+
const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
|
|
4724
4988
|
|
|
4725
4989
|
const float priorities[] = { 1.0f, 1.0f };
|
|
4726
4990
|
device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
|
|
@@ -4895,11 +5159,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
4895
5159
|
|
|
4896
5160
|
#if defined(VK_KHR_cooperative_matrix)
|
|
4897
5161
|
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
|
4898
|
-
|
|
4899
|
-
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
|
4900
|
-
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
|
4901
|
-
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
|
4902
|
-
device->subgroup_max_size >= 32;
|
|
5162
|
+
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
|
|
4903
5163
|
#endif
|
|
4904
5164
|
|
|
4905
5165
|
if (coopmat2_support) {
|
|
@@ -5186,10 +5446,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
5186
5446
|
if (!device->single_queue) {
|
|
5187
5447
|
const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
|
|
5188
5448
|
ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
|
|
5449
|
+
|
|
5450
|
+
device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
|
|
5189
5451
|
} else {
|
|
5190
5452
|
// TODO: Use pointer or reference to avoid copy
|
|
5191
5453
|
device->transfer_queue.copyFrom(device->compute_queue);
|
|
5192
5454
|
device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
|
|
5455
|
+
|
|
5456
|
+
device->async_use_transfer_queue = false;
|
|
5193
5457
|
}
|
|
5194
5458
|
|
|
5195
5459
|
device->buffer_type = {
|
|
@@ -5467,6 +5731,10 @@ static void ggml_vk_instance_init() {
|
|
|
5467
5731
|
vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
|
|
5468
5732
|
vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
|
|
5469
5733
|
vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
|
|
5734
|
+
const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
|
|
5735
|
+
if (GGML_VK_PIPELINE_STATS != nullptr) {
|
|
5736
|
+
vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
|
|
5737
|
+
}
|
|
5470
5738
|
const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
|
|
5471
5739
|
|
|
5472
5740
|
if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
|
|
@@ -5513,22 +5781,30 @@ static void ggml_vk_instance_init() {
|
|
|
5513
5781
|
|
|
5514
5782
|
if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
|
|
5515
5783
|
// Check if there are two physical devices corresponding to the same GPU
|
|
5784
|
+
// This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
|
|
5785
|
+
// see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
|
|
5786
|
+
// MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
|
|
5787
|
+
// see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
|
|
5788
|
+
// driver is MoltenVK
|
|
5516
5789
|
auto old_device = std::find_if(
|
|
5517
5790
|
vk_instance.device_indices.begin(),
|
|
5518
5791
|
vk_instance.device_indices.end(),
|
|
5519
|
-
[&devices, &new_id](const size_t k){
|
|
5792
|
+
[&devices, &new_id, &new_driver](const size_t k){
|
|
5520
5793
|
vk::PhysicalDeviceProperties2 old_props;
|
|
5794
|
+
vk::PhysicalDeviceDriverProperties old_driver;
|
|
5521
5795
|
vk::PhysicalDeviceIDProperties old_id;
|
|
5522
|
-
old_props.pNext = &
|
|
5796
|
+
old_props.pNext = &old_driver;
|
|
5797
|
+
old_driver.pNext = &old_id;
|
|
5523
5798
|
devices[k].getProperties2(&old_props);
|
|
5524
5799
|
|
|
5525
|
-
bool
|
|
5526
|
-
|
|
5800
|
+
bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
|
|
5801
|
+
same_uuid = same_uuid || (
|
|
5527
5802
|
old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
|
|
5528
5803
|
std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
|
|
5529
5804
|
);
|
|
5805
|
+
bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
|
|
5530
5806
|
|
|
5531
|
-
return
|
|
5807
|
+
return same_uuid && !both_molten_vk;
|
|
5532
5808
|
}
|
|
5533
5809
|
);
|
|
5534
5810
|
if (old_device == vk_instance.device_indices.end()) {
|
|
@@ -5565,6 +5841,10 @@ static void ggml_vk_instance_init() {
|
|
|
5565
5841
|
driver_priorities[vk::DriverId::eMesaNvk] = 2;
|
|
5566
5842
|
#endif
|
|
5567
5843
|
break;
|
|
5844
|
+
case VK_VENDOR_ID_QUALCOMM:
|
|
5845
|
+
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
|
|
5846
|
+
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
|
|
5847
|
+
break;
|
|
5568
5848
|
}
|
|
5569
5849
|
driver_priorities[vk::DriverId::eMesaDozen] = 100;
|
|
5570
5850
|
|
|
@@ -5647,7 +5927,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
|
|
|
5647
5927
|
ctx->almost_ready_fence = ctx->device->device.createFence({});
|
|
5648
5928
|
|
|
5649
5929
|
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
|
|
5650
|
-
|
|
5930
|
+
if (ctx->device->async_use_transfer_queue) {
|
|
5931
|
+
vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
|
|
5932
|
+
vk::SemaphoreCreateInfo ci{};
|
|
5933
|
+
ci.setPNext(&tci);
|
|
5934
|
+
ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);
|
|
5935
|
+
ctx->transfer_semaphore.value = 0;
|
|
5936
|
+
|
|
5937
|
+
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
|
|
5938
|
+
}
|
|
5651
5939
|
|
|
5652
5940
|
if (vk_perf_logger_enabled) {
|
|
5653
5941
|
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
|
@@ -6100,13 +6388,24 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
|
|
|
6100
6388
|
return vk_subbuffer{buffer, offset, size};
|
|
6101
6389
|
}
|
|
6102
6390
|
|
|
6391
|
+
// Get a command buffer from pool. Create a new one if no reusable buffer is available
|
|
6392
|
+
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
|
|
6393
|
+
for (auto& cmd_buffer : pool.cmd_buffers) {
|
|
6394
|
+
if (!cmd_buffer.in_use) {
|
|
6395
|
+
cmd_buffer.in_use = true;
|
|
6396
|
+
return &cmd_buffer;
|
|
6397
|
+
}
|
|
6398
|
+
}
|
|
6399
|
+
return ggml_vk_create_cmd_buffer(device, pool);
|
|
6400
|
+
}
|
|
6401
|
+
|
|
6103
6402
|
static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
|
|
6104
6403
|
vk_submission s;
|
|
6105
|
-
s.buffer =
|
|
6404
|
+
s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);
|
|
6106
6405
|
if (one_time) {
|
|
6107
|
-
s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
|
|
6406
|
+
s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
|
|
6108
6407
|
} else {
|
|
6109
|
-
s.buffer.begin({ vk::CommandBufferUsageFlags{} });
|
|
6408
|
+
s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });
|
|
6110
6409
|
}
|
|
6111
6410
|
|
|
6112
6411
|
return s;
|
|
@@ -6159,18 +6458,18 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
|
|
|
6159
6458
|
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
|
|
6160
6459
|
ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
|
|
6161
6460
|
|
|
6162
|
-
subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
|
|
6163
|
-
subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
|
|
6164
|
-
subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
|
|
6461
|
+
subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
|
|
6462
|
+
subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
|
|
6463
|
+
subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
|
|
6165
6464
|
pipeline->layout,
|
|
6166
6465
|
0,
|
|
6167
6466
|
{ descriptor_set },
|
|
6168
6467
|
{});
|
|
6169
|
-
subctx->s->buffer.dispatch(wg0, wg1, wg2);
|
|
6468
|
+
subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
|
|
6170
6469
|
}
|
|
6171
6470
|
|
|
6172
6471
|
static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
|
|
6173
|
-
s.buffer.end();
|
|
6472
|
+
s.buffer->buf.end();
|
|
6174
6473
|
|
|
6175
6474
|
s.wait_semaphores = std::move(wait_semaphores);
|
|
6176
6475
|
s.signal_semaphores = std::move(signal_semaphores);
|
|
@@ -6182,7 +6481,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) {
|
|
|
6182
6481
|
return;
|
|
6183
6482
|
}
|
|
6184
6483
|
|
|
6185
|
-
ctx->s->buffer.end();
|
|
6484
|
+
ctx->s->buffer->buf.end();
|
|
6186
6485
|
ctx->s = nullptr;
|
|
6187
6486
|
}
|
|
6188
6487
|
|
|
@@ -6196,6 +6495,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
|
|
|
6196
6495
|
subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
|
|
6197
6496
|
}
|
|
6198
6497
|
|
|
6498
|
+
static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
|
|
6499
|
+
if (!ctx->compute_ctx.expired()) {
|
|
6500
|
+
return ctx->compute_ctx.lock();
|
|
6501
|
+
}
|
|
6502
|
+
|
|
6503
|
+
vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
6504
|
+
|
|
6505
|
+
ctx->compute_ctx = result;
|
|
6506
|
+
ggml_vk_ctx_begin(ctx->device, result);
|
|
6507
|
+
|
|
6508
|
+
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
|
|
6509
|
+
result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
|
|
6510
|
+
ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
|
|
6511
|
+
}
|
|
6512
|
+
|
|
6513
|
+
return result;
|
|
6514
|
+
}
|
|
6515
|
+
|
|
6516
|
+
// Submit any pending transfer queue work and signal the transfer semaphore.
|
|
6517
|
+
// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.
|
|
6518
|
+
// Returns true if work was submitted.
|
|
6519
|
+
static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
|
|
6520
|
+
if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
|
|
6521
|
+
return false;
|
|
6522
|
+
}
|
|
6523
|
+
|
|
6524
|
+
vk_context cpy_ctx = ctx->transfer_ctx.lock();
|
|
6525
|
+
ggml_vk_ctx_end(cpy_ctx);
|
|
6526
|
+
|
|
6527
|
+
for (auto& cpy : cpy_ctx->in_memcpys) {
|
|
6528
|
+
memcpy(cpy.dst, cpy.src, cpy.n);
|
|
6529
|
+
}
|
|
6530
|
+
|
|
6531
|
+
ctx->transfer_semaphore.value++;
|
|
6532
|
+
cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
|
|
6533
|
+
|
|
6534
|
+
ggml_vk_submit(cpy_ctx, {});
|
|
6535
|
+
ctx->transfer_ctx.reset();
|
|
6536
|
+
return true;
|
|
6537
|
+
}
|
|
6538
|
+
|
|
6199
6539
|
static size_t ggml_vk_align_size(size_t width, size_t align) {
|
|
6200
6540
|
VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
|
|
6201
6541
|
return CEIL_DIV(width, align) * align;
|
|
@@ -6295,7 +6635,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
|
|
|
6295
6635
|
}
|
|
6296
6636
|
|
|
6297
6637
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
6298
|
-
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
|
|
6638
|
+
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
|
|
6299
6639
|
return;
|
|
6300
6640
|
}
|
|
6301
6641
|
|
|
@@ -6310,7 +6650,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
|
|
|
6310
6650
|
VkBufferCopy buf_copy{ 0, offset, copy_size };
|
|
6311
6651
|
|
|
6312
6652
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
6313
|
-
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
|
|
6653
|
+
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
|
|
6314
6654
|
|
|
6315
6655
|
for (uint64_t i3 = 0; i3 < ne3; i3++) {
|
|
6316
6656
|
for (uint64_t i2 = 0; i2 < ne2; i2++) {
|
|
@@ -6359,7 +6699,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6359
6699
|
}
|
|
6360
6700
|
|
|
6361
6701
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6362
|
-
subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
|
|
6702
|
+
subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
|
|
6363
6703
|
return true;
|
|
6364
6704
|
}
|
|
6365
6705
|
VK_LOG_DEBUG("STAGING");
|
|
@@ -6381,7 +6721,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
|
|
|
6381
6721
|
copy_size};
|
|
6382
6722
|
|
|
6383
6723
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6384
|
-
vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
|
|
6724
|
+
vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
|
|
6385
6725
|
|
|
6386
6726
|
if (width == spitch) {
|
|
6387
6727
|
deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
|
|
@@ -6467,7 +6807,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
|
|
|
6467
6807
|
if (buf != nullptr) {
|
|
6468
6808
|
// Memory is pinned, use as staging buffer
|
|
6469
6809
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6470
|
-
subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
|
|
6810
|
+
subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);
|
|
6471
6811
|
|
|
6472
6812
|
return true;
|
|
6473
6813
|
}
|
|
@@ -6485,7 +6825,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
|
|
|
6485
6825
|
vk_buffer& staging_buffer = src->device->sync_staging;
|
|
6486
6826
|
|
|
6487
6827
|
ggml_vk_sync_buffers(nullptr, subctx);
|
|
6488
|
-
subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
|
|
6828
|
+
subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
|
|
6489
6829
|
|
|
6490
6830
|
deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
|
|
6491
6831
|
return true;
|
|
@@ -6532,7 +6872,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
|
|
|
6532
6872
|
|
|
6533
6873
|
VkBufferCopy bc{ src_offset, dst_offset, size };
|
|
6534
6874
|
|
|
6535
|
-
vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
|
|
6875
|
+
vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
|
|
6536
6876
|
}
|
|
6537
6877
|
|
|
6538
6878
|
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
|
|
@@ -6570,7 +6910,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
|
|
|
6570
6910
|
}
|
|
6571
6911
|
|
|
6572
6912
|
// Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
|
|
6573
|
-
ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
|
|
6913
|
+
ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
|
|
6574
6914
|
}
|
|
6575
6915
|
|
|
6576
6916
|
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
|
|
@@ -6585,7 +6925,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
|
|
|
6585
6925
|
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
6586
6926
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
6587
6927
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
6588
|
-
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
|
|
6928
|
+
subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
|
|
6589
6929
|
ggml_vk_ctx_end(subctx);
|
|
6590
6930
|
|
|
6591
6931
|
ggml_vk_submit(subctx, dst->device->fence);
|
|
@@ -6691,8 +7031,16 @@ static void ggml_vk_matmul(
|
|
|
6691
7031
|
uint32_t padded_n) {
|
|
6692
7032
|
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
|
6693
7033
|
if (split_k == 1) {
|
|
6694
|
-
|
|
6695
|
-
|
|
7034
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
|
7035
|
+
|
|
7036
|
+
uint32_t base_work_group_z = 0;
|
|
7037
|
+
while (base_work_group_z < batch) {
|
|
7038
|
+
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
7039
|
+
|
|
7040
|
+
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
7041
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
|
|
7042
|
+
base_work_group_z += groups_z;
|
|
7043
|
+
}
|
|
6696
7044
|
return;
|
|
6697
7045
|
}
|
|
6698
7046
|
|
|
@@ -6706,9 +7054,17 @@ static void ggml_vk_matmul(
|
|
|
6706
7054
|
uint32_t k_split = CEIL_DIV(k, split_k);
|
|
6707
7055
|
k_split = ROUNDUP_POW2(k_split, 256);
|
|
6708
7056
|
|
|
6709
|
-
|
|
6710
|
-
|
|
6711
|
-
|
|
7057
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
|
7058
|
+
|
|
7059
|
+
uint32_t base_work_group_z = 0;
|
|
7060
|
+
while (base_work_group_z < batch) {
|
|
7061
|
+
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
|
7062
|
+
|
|
7063
|
+
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
7064
|
+
// Make sure enough workgroups get assigned for split k to work
|
|
7065
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
|
|
7066
|
+
base_work_group_z += groups_z;
|
|
7067
|
+
}
|
|
6712
7068
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
6713
7069
|
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
|
|
6714
7070
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
|
|
@@ -7104,7 +7460,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
7104
7460
|
}
|
|
7105
7461
|
|
|
7106
7462
|
// Request descriptor sets
|
|
7107
|
-
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
7108
7463
|
if (qx_needs_dequant) {
|
|
7109
7464
|
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
|
|
7110
7465
|
}
|
|
@@ -7274,6 +7629,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
|
|
|
7274
7629
|
return false;
|
|
7275
7630
|
}
|
|
7276
7631
|
|
|
7632
|
+
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
|
|
7633
|
+
// Intel Windows proprietary driver tuning
|
|
7634
|
+
switch (src0_type) {
|
|
7635
|
+
case GGML_TYPE_MXFP4:
|
|
7636
|
+
case GGML_TYPE_Q4_K:
|
|
7637
|
+
case GGML_TYPE_Q5_K:
|
|
7638
|
+
return false;
|
|
7639
|
+
default:
|
|
7640
|
+
return true;
|
|
7641
|
+
}
|
|
7642
|
+
}
|
|
7643
|
+
|
|
7277
7644
|
switch (src0_type) {
|
|
7278
7645
|
// From tests on A770 Linux, may need more tuning
|
|
7279
7646
|
case GGML_TYPE_Q4_0:
|
|
@@ -7402,7 +7769,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
7402
7769
|
if (quantize_y) {
|
|
7403
7770
|
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
|
7404
7771
|
}
|
|
7405
|
-
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
|
7406
7772
|
}
|
|
7407
7773
|
|
|
7408
7774
|
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
|
@@ -7497,22 +7863,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
7497
7863
|
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
|
|
7498
7864
|
}
|
|
7499
7865
|
|
|
7500
|
-
|
|
7501
|
-
|
|
7502
|
-
|
|
7503
|
-
|
|
7504
|
-
|
|
7505
|
-
|
|
7506
|
-
|
|
7507
|
-
|
|
7508
|
-
|
|
7509
|
-
|
|
7510
|
-
|
|
7511
|
-
|
|
7512
|
-
|
|
7513
|
-
|
|
7514
|
-
|
|
7515
|
-
|
|
7866
|
+
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
|
|
7867
|
+
|
|
7868
|
+
uint32_t base_work_group_y = 0;
|
|
7869
|
+
while (base_work_group_y < ne12 * ne13) {
|
|
7870
|
+
|
|
7871
|
+
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
7872
|
+
const vk_mat_vec_push_constants pc = {
|
|
7873
|
+
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
|
7874
|
+
stride_batch_x, stride_batch_y, stride_batch_d,
|
|
7875
|
+
fusion_flags, base_work_group_y,
|
|
7876
|
+
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
|
7877
|
+
};
|
|
7878
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
|
7879
|
+
{
|
|
7880
|
+
d_X,
|
|
7881
|
+
d_Y,
|
|
7882
|
+
d_D,
|
|
7883
|
+
d_F0,
|
|
7884
|
+
d_F1,
|
|
7885
|
+
},
|
|
7886
|
+
pc, { groups_x, groups_y, groups_z });
|
|
7887
|
+
base_work_group_y += groups_y;
|
|
7888
|
+
}
|
|
7516
7889
|
|
|
7517
7890
|
if (x_non_contig) {
|
|
7518
7891
|
ctx->prealloc_x_need_sync = true;
|
|
@@ -7750,10 +8123,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
7750
8123
|
src1->nb[2] <= src1->nb[1] &&
|
|
7751
8124
|
src1->nb[1] <= src1->nb[3] &&
|
|
7752
8125
|
src0->ne[3] == 1 &&
|
|
7753
|
-
src1->ne[3] == 1
|
|
8126
|
+
src1->ne[3] == 1 &&
|
|
8127
|
+
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
|
8128
|
+
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
|
7754
8129
|
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
|
|
7755
8130
|
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
|
7756
|
-
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)
|
|
8131
|
+
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
|
|
8132
|
+
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
|
8133
|
+
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
|
8134
|
+
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
|
7757
8135
|
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
|
|
7758
8136
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
|
7759
8137
|
// when ne12 and ne13 are one.
|
|
@@ -8083,8 +8461,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
8083
8461
|
|
|
8084
8462
|
const uint64_t nei0 = ids->ne[0];
|
|
8085
8463
|
const uint64_t nei1 = ids->ne[1];
|
|
8086
|
-
|
|
8087
|
-
GGML_ASSERT(nei1 == 1);
|
|
8464
|
+
const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
|
|
8088
8465
|
|
|
8089
8466
|
const uint64_t ne20 = dst->ne[0];
|
|
8090
8467
|
const uint64_t ne21 = dst->ne[1];
|
|
@@ -8168,7 +8545,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
8168
8545
|
if (quantize_y) {
|
|
8169
8546
|
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
|
8170
8547
|
}
|
|
8171
|
-
ggml_pipeline_request_descriptor_sets(ctx, dmmv,
|
|
8548
|
+
ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
|
|
8172
8549
|
}
|
|
8173
8550
|
|
|
8174
8551
|
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
|
@@ -8226,7 +8603,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
8226
8603
|
uint32_t stride_batch_y = ne10*ne11;
|
|
8227
8604
|
|
|
8228
8605
|
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
|
8229
|
-
stride_batch_y = src1->nb[
|
|
8606
|
+
stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
|
|
8230
8607
|
}
|
|
8231
8608
|
|
|
8232
8609
|
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
|
@@ -8262,23 +8639,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
8262
8639
|
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
|
|
8263
8640
|
}
|
|
8264
8641
|
|
|
8265
|
-
//
|
|
8266
|
-
|
|
8267
|
-
|
|
8268
|
-
|
|
8269
|
-
|
|
8270
|
-
|
|
8271
|
-
|
|
8272
|
-
|
|
8273
|
-
|
|
8274
|
-
|
|
8275
|
-
|
|
8276
|
-
|
|
8277
|
-
|
|
8278
|
-
|
|
8279
|
-
|
|
8280
|
-
|
|
8281
|
-
|
|
8642
|
+
// Loop over the batch dimension
|
|
8643
|
+
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
|
|
8644
|
+
const vk_mat_vec_id_push_constants pc = {
|
|
8645
|
+
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
|
8646
|
+
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
|
|
8647
|
+
fusion_flags,
|
|
8648
|
+
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
|
|
8649
|
+
};
|
|
8650
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
|
8651
|
+
{
|
|
8652
|
+
d_X,
|
|
8653
|
+
d_Y,
|
|
8654
|
+
d_D,
|
|
8655
|
+
d_F0,
|
|
8656
|
+
d_F1,
|
|
8657
|
+
d_ids,
|
|
8658
|
+
},
|
|
8659
|
+
pc, { groups_x, (uint32_t)nei0, groups_z });
|
|
8660
|
+
}
|
|
8282
8661
|
|
|
8283
8662
|
if (x_non_contig) {
|
|
8284
8663
|
ctx->prealloc_x_need_sync = true;
|
|
@@ -8292,7 +8671,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
|
|
|
8292
8671
|
ggml_tensor * dst = cgraph->nodes[node_idx];
|
|
8293
8672
|
ggml_tensor * src0 = dst->src[0];
|
|
8294
8673
|
ggml_tensor * src2 = dst->src[2];
|
|
8295
|
-
return src2->ne[1]
|
|
8674
|
+
return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
|
|
8296
8675
|
}
|
|
8297
8676
|
|
|
8298
8677
|
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
@@ -8308,55 +8687,70 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8308
8687
|
}
|
|
8309
8688
|
}
|
|
8310
8689
|
|
|
8311
|
-
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool
|
|
8690
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
8691
|
+
GGML_UNUSED(f32acc);
|
|
8312
8692
|
// Needs to be kept up to date on shader changes
|
|
8313
|
-
|
|
8314
|
-
const uint32_t
|
|
8315
|
-
const uint32_t
|
|
8316
|
-
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
8693
|
+
const uint32_t wg_size = params.workgroup_size;
|
|
8694
|
+
const uint32_t Br = params.block_rows;
|
|
8695
|
+
const uint32_t Bc = params.block_cols;
|
|
8317
8696
|
|
|
8697
|
+
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
|
8698
|
+
|
|
8699
|
+
// tmpsh is overestimated slightly
|
|
8318
8700
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
8319
|
-
const uint32_t tmpshv4 = wg_size * 4 *
|
|
8701
|
+
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
|
8702
|
+
|
|
8703
|
+
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
|
8320
8704
|
|
|
8321
|
-
const uint32_t
|
|
8705
|
+
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
|
8322
8706
|
|
|
8323
|
-
const uint32_t
|
|
8707
|
+
const uint32_t D = std::max(hsk, hsv);
|
|
8708
|
+
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
|
8324
8709
|
|
|
8325
|
-
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
|
8710
|
+
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
|
8326
8711
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
8327
8712
|
|
|
8328
|
-
VK_LOG_DEBUG("
|
|
8713
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
|
8329
8714
|
|
|
8330
8715
|
return supported;
|
|
8331
8716
|
}
|
|
8332
8717
|
|
|
8333
|
-
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
8718
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
8334
8719
|
// Needs to be kept up to date on shader changes
|
|
8335
|
-
|
|
8336
|
-
const uint32_t
|
|
8337
|
-
|
|
8338
|
-
const uint32_t
|
|
8720
|
+
const uint32_t Br = params.block_rows;
|
|
8721
|
+
const uint32_t Bc = params.block_cols;
|
|
8722
|
+
|
|
8723
|
+
const uint32_t MatBr = 16, MatBc = 16;
|
|
8724
|
+
|
|
8725
|
+
const uint32_t row_split = Bc / MatBc;
|
|
8339
8726
|
|
|
8340
8727
|
const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
|
|
8728
|
+
const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
|
|
8341
8729
|
|
|
8342
8730
|
const uint32_t acctype = f32acc ? 4 : 2;
|
|
8343
8731
|
const uint32_t f16vec4 = 8;
|
|
8344
8732
|
|
|
8345
|
-
const uint32_t tmpsh =
|
|
8346
|
-
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
8733
|
+
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
|
|
8347
8734
|
|
|
8348
8735
|
const uint32_t qstride = hsk_pad / 4 + 2;
|
|
8349
8736
|
const uint32_t Qf = Br * qstride * f16vec4;
|
|
8350
8737
|
|
|
8738
|
+
const uint32_t psh_stride = Br / 4 + 2;
|
|
8739
|
+
const uint32_t Psh = Bc * psh_stride * f16vec4;
|
|
8740
|
+
|
|
8351
8741
|
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
|
8352
8742
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
8353
8743
|
|
|
8354
|
-
const uint32_t
|
|
8355
|
-
const uint32_t
|
|
8744
|
+
const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
|
|
8745
|
+
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
|
8746
|
+
const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
|
|
8356
8747
|
|
|
8357
|
-
const uint32_t
|
|
8748
|
+
const uint32_t osh_stride = params.row_split * MatBr / 4;
|
|
8749
|
+
const uint32_t pvsh = MatBc * osh_stride * f16vec4;
|
|
8358
8750
|
|
|
8359
|
-
const uint32_t
|
|
8751
|
+
const uint32_t slope = Br * acctype;
|
|
8752
|
+
|
|
8753
|
+
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
|
|
8360
8754
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
8361
8755
|
|
|
8362
8756
|
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
@@ -8383,6 +8777,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8383
8777
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
|
8384
8778
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
8385
8779
|
|
|
8780
|
+
const uint32_t nem0 = mask ? mask->ne[0] : 0;
|
|
8386
8781
|
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
|
8387
8782
|
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
|
8388
8783
|
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
|
@@ -8416,72 +8811,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8416
8811
|
assert(q->type == GGML_TYPE_F32);
|
|
8417
8812
|
assert(k->type == v->type);
|
|
8418
8813
|
|
|
8419
|
-
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
|
8420
|
-
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
|
8421
|
-
|
|
8422
|
-
if (path == FA_COOPMAT1) {
|
|
8423
|
-
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
8424
|
-
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
8425
|
-
|
|
8426
|
-
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
|
8427
|
-
|
|
8428
|
-
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
8429
|
-
path = FA_SCALAR;
|
|
8430
|
-
}
|
|
8431
|
-
}
|
|
8432
|
-
|
|
8433
8814
|
uint32_t gqa_ratio = 1;
|
|
8434
8815
|
uint32_t qk_ratio = neq2 / nek2;
|
|
8435
8816
|
uint32_t workgroups_x = (uint32_t)neq1;
|
|
8436
8817
|
uint32_t workgroups_y = (uint32_t)neq2;
|
|
8437
8818
|
uint32_t workgroups_z = (uint32_t)neq3;
|
|
8438
8819
|
|
|
8439
|
-
const bool
|
|
8820
|
+
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
|
|
8440
8821
|
|
|
8441
8822
|
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
|
8442
8823
|
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
|
8443
|
-
|
|
8444
|
-
|
|
8445
|
-
case FA_SCALAR:
|
|
8446
|
-
case FA_COOPMAT1:
|
|
8447
|
-
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
8448
|
-
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
|
|
8449
|
-
break;
|
|
8450
|
-
case FA_COOPMAT2:
|
|
8451
|
-
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
8452
|
-
break;
|
|
8453
|
-
default:
|
|
8454
|
-
GGML_ASSERT(0);
|
|
8455
|
-
}
|
|
8824
|
+
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
|
|
8825
|
+
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
|
|
8456
8826
|
|
|
8457
|
-
if (N
|
|
8827
|
+
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
8458
8828
|
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
|
8459
8829
|
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
8460
8830
|
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
8461
8831
|
// and change addressing calculations to index Q's dimension 2.
|
|
8462
8832
|
gqa_ratio = qk_ratio;
|
|
8463
8833
|
N = gqa_ratio;
|
|
8464
|
-
workgroups_y /=
|
|
8465
|
-
}
|
|
8466
|
-
|
|
8467
|
-
bool small_rows = N <= get_fa_num_small_rows(path);
|
|
8468
|
-
|
|
8469
|
-
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
|
8470
|
-
// So use scalar instead.
|
|
8471
|
-
if (small_rows && path == FA_COOPMAT1) {
|
|
8472
|
-
path = FA_SCALAR;
|
|
8834
|
+
workgroups_y /= gqa_ratio;
|
|
8473
8835
|
}
|
|
8474
8836
|
|
|
8475
|
-
|
|
8476
|
-
if (N == 1 && path == FA_COOPMAT2) {
|
|
8477
|
-
path = FA_SCALAR;
|
|
8478
|
-
}
|
|
8479
|
-
|
|
8480
|
-
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
|
8481
|
-
if (path == FA_SCALAR &&
|
|
8482
|
-
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
|
|
8483
|
-
small_rows = true;
|
|
8484
|
-
}
|
|
8837
|
+
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
|
|
8485
8838
|
|
|
8486
8839
|
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
|
8487
8840
|
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
|
@@ -8495,19 +8848,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8495
8848
|
v_stride /= 4;
|
|
8496
8849
|
}
|
|
8497
8850
|
|
|
8498
|
-
uint32_t alignment =
|
|
8851
|
+
const uint32_t alignment = tuning_params.block_cols;
|
|
8499
8852
|
bool aligned = (KV % alignment) == 0 &&
|
|
8500
8853
|
// the "aligned" shader variant will forcibly align strides, for performance
|
|
8501
8854
|
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
|
8502
8855
|
|
|
8503
8856
|
// Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
|
|
8504
|
-
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
|
|
8857
|
+
if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
|
|
8505
8858
|
aligned = false;
|
|
8506
8859
|
}
|
|
8507
8860
|
|
|
8508
|
-
|
|
8861
|
+
float scale = 1.0f;
|
|
8862
|
+
float max_bias = 0.0f;
|
|
8863
|
+
float logit_softcap = 0.0f;
|
|
8509
8864
|
|
|
8510
|
-
|
|
8865
|
+
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
|
8866
|
+
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
|
8867
|
+
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
|
|
8868
|
+
|
|
8869
|
+
if (logit_softcap != 0) {
|
|
8870
|
+
scale /= logit_softcap;
|
|
8871
|
+
}
|
|
8872
|
+
|
|
8873
|
+
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
|
8874
|
+
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
|
|
8875
|
+
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
|
8876
|
+
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
|
8511
8877
|
|
|
8512
8878
|
vk_pipeline pipeline = nullptr;
|
|
8513
8879
|
|
|
@@ -8523,29 +8889,46 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8523
8889
|
}
|
|
8524
8890
|
|
|
8525
8891
|
assert(pipeline);
|
|
8892
|
+
// Compile early to initialize wg_denoms.
|
|
8893
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
8526
8894
|
|
|
8527
8895
|
uint32_t split_kv = KV;
|
|
8528
8896
|
uint32_t split_k = 1;
|
|
8529
8897
|
|
|
8898
|
+
// Intel Alchemist prefers more workgroups
|
|
8899
|
+
const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
|
|
8900
|
+
|
|
8530
8901
|
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
|
8531
|
-
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
8902
|
+
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
|
|
8532
8903
|
|
|
8533
|
-
|
|
8534
|
-
|
|
8535
|
-
|
|
8536
|
-
|
|
8537
|
-
|
|
8538
|
-
|
|
8539
|
-
|
|
8540
|
-
|
|
8541
|
-
|
|
8542
|
-
|
|
8904
|
+
const uint32_t Br = fa_pipeline_state.Br;
|
|
8905
|
+
const uint32_t Bc = fa_pipeline_state.Bc;
|
|
8906
|
+
|
|
8907
|
+
GGML_ASSERT(Br == pipeline->wg_denoms[0]);
|
|
8908
|
+
const uint32_t Tr = CEIL_DIV(N, Br);
|
|
8909
|
+
|
|
8910
|
+
// Try to use split_k when KV is large enough to be worth the overhead.
|
|
8911
|
+
if (gqa_ratio > 1 && workgroups_x <= Br) {
|
|
8912
|
+
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
|
|
8913
|
+
} else if (gqa_ratio <= 1) {
|
|
8914
|
+
uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
|
|
8915
|
+
if (total_wgs_no_split < shader_core_count * 2) {
|
|
8916
|
+
split_k = shader_core_count * 2 / total_wgs_no_split;
|
|
8543
8917
|
}
|
|
8544
8918
|
}
|
|
8545
8919
|
|
|
8920
|
+
if (split_k > 1) {
|
|
8921
|
+
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
|
8922
|
+
// of "align", so recompute split_k based on that.
|
|
8923
|
+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
|
8924
|
+
split_k = CEIL_DIV(KV, split_kv);
|
|
8925
|
+
}
|
|
8926
|
+
|
|
8546
8927
|
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
|
8547
8928
|
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
|
8548
|
-
|
|
8929
|
+
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
|
|
8930
|
+
// For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
|
|
8931
|
+
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
|
|
8549
8932
|
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
|
|
8550
8933
|
GGML_ABORT("Requested preallocation size is too large");
|
|
8551
8934
|
}
|
|
@@ -8554,24 +8937,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8554
8937
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
|
8555
8938
|
}
|
|
8556
8939
|
|
|
8557
|
-
|
|
8558
|
-
|
|
8559
|
-
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
8560
|
-
if (split_k > 1) {
|
|
8561
|
-
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
|
|
8562
|
-
}
|
|
8563
|
-
}
|
|
8564
|
-
|
|
8565
|
-
float scale = 1.0f;
|
|
8566
|
-
float max_bias = 0.0f;
|
|
8567
|
-
float logit_softcap = 0.0f;
|
|
8940
|
+
const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
|
|
8941
|
+
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
|
|
8568
8942
|
|
|
8569
|
-
|
|
8570
|
-
|
|
8571
|
-
|
|
8943
|
+
vk_pipeline pipeline_fa_mask_opt = nullptr;
|
|
8944
|
+
if (use_mask_opt) {
|
|
8945
|
+
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
|
|
8946
|
+
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
|
|
8947
|
+
auto it = pipelines.find({Br, Bc});
|
|
8948
|
+
if (it != pipelines.end()) {
|
|
8949
|
+
pipeline_fa_mask_opt = it->second;
|
|
8950
|
+
} else {
|
|
8951
|
+
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
|
|
8952
|
+
}
|
|
8953
|
+
assert(pipeline_fa_mask_opt);
|
|
8954
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
|
|
8572
8955
|
|
|
8573
|
-
|
|
8574
|
-
|
|
8956
|
+
if (ctx->prealloc_size_y < mask_opt_size) {
|
|
8957
|
+
ctx->prealloc_size_y = mask_opt_size;
|
|
8958
|
+
ggml_vk_preallocate_buffers(ctx, subctx);
|
|
8959
|
+
}
|
|
8960
|
+
if (ctx->prealloc_y_need_sync) {
|
|
8961
|
+
ggml_vk_sync_buffers(ctx, subctx);
|
|
8962
|
+
}
|
|
8575
8963
|
}
|
|
8576
8964
|
|
|
8577
8965
|
const uint32_t n_head_kv = neq2;
|
|
@@ -8585,8 +8973,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8585
8973
|
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
|
8586
8974
|
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
|
|
8587
8975
|
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
|
|
8976
|
+
vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
|
|
8977
|
+
|
|
8978
|
+
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
|
|
8979
|
+
|
|
8980
|
+
if (use_mask_opt)
|
|
8981
|
+
{
|
|
8982
|
+
const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
|
|
8983
|
+
nem0,
|
|
8984
|
+
nem1,
|
|
8985
|
+
nem2,
|
|
8986
|
+
(uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
|
|
8987
|
+
(uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
|
|
8988
|
+
(uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
|
|
8989
|
+
mask_opt_num_dwords,
|
|
8990
|
+
mask_opt_num_dwords * CEIL_DIV(nem1, Br),
|
|
8991
|
+
mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
|
|
8992
|
+
};
|
|
8588
8993
|
|
|
8589
|
-
|
|
8994
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
|
|
8995
|
+
{ mask_buf, mask_opt_buf }, opt_pc,
|
|
8996
|
+
{ mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
|
|
8997
|
+
ggml_vk_sync_buffers(ctx, subctx);
|
|
8998
|
+
}
|
|
8590
8999
|
|
|
8591
9000
|
const vk_flash_attn_push_constants pc = { N, KV,
|
|
8592
9001
|
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
|
@@ -8602,28 +9011,40 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
8602
9011
|
gqa_ratio, split_kv, split_k };
|
|
8603
9012
|
|
|
8604
9013
|
if (split_k > 1) {
|
|
9014
|
+
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
|
|
9015
|
+
|
|
8605
9016
|
if (ctx->prealloc_split_k_need_sync) {
|
|
8606
9017
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8607
9018
|
}
|
|
8608
9019
|
|
|
9020
|
+
// We reuse workgroups_x to mean the number of splits, so we need to
|
|
9021
|
+
// cancel out the divide by wg_denoms[0].
|
|
9022
|
+
uint32_t dispatch_x;
|
|
9023
|
+
if (gqa_ratio > 1) {
|
|
9024
|
+
workgroups_x *= pipeline->wg_denoms[0];
|
|
9025
|
+
dispatch_x = split_k * workgroups_x;
|
|
9026
|
+
} else {
|
|
9027
|
+
dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
|
|
9028
|
+
}
|
|
9029
|
+
|
|
8609
9030
|
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
|
|
8610
9031
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
8611
|
-
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
|
|
8612
|
-
|
|
8613
|
-
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
|
8614
|
-
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
|
8615
|
-
// cancel out the divide by wg_denoms[0].
|
|
8616
|
-
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
|
9032
|
+
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
|
|
9033
|
+
pc, { dispatch_x, workgroups_y, workgroups_z });
|
|
8617
9034
|
|
|
8618
9035
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
8619
|
-
const
|
|
9036
|
+
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
|
8620
9037
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
|
8621
9038
|
{split_k_buf, sinks_buf, dst_buf},
|
|
8622
|
-
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
|
9039
|
+
pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
|
|
8623
9040
|
ctx->prealloc_split_k_need_sync = true;
|
|
8624
9041
|
} else {
|
|
9042
|
+
if (gqa_ratio > 1) {
|
|
9043
|
+
// When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
|
|
9044
|
+
workgroups_x *= pipeline->wg_denoms[0];
|
|
9045
|
+
}
|
|
8625
9046
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
8626
|
-
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
|
|
9047
|
+
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
|
|
8627
9048
|
pc, { workgroups_x, workgroups_y, workgroups_z });
|
|
8628
9049
|
}
|
|
8629
9050
|
}
|
|
@@ -8668,6 +9089,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
8668
9089
|
return ctx->device->pipeline_acc_f32;
|
|
8669
9090
|
}
|
|
8670
9091
|
return nullptr;
|
|
9092
|
+
case GGML_OP_SET:
|
|
9093
|
+
if (src0->type == src1->type && src0->type == dst->type &&
|
|
9094
|
+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
|
|
9095
|
+
return ctx->device->pipeline_set_f32;
|
|
9096
|
+
}
|
|
9097
|
+
return nullptr;
|
|
8671
9098
|
case GGML_OP_ADD:
|
|
8672
9099
|
case GGML_OP_SUB:
|
|
8673
9100
|
case GGML_OP_MUL:
|
|
@@ -8869,6 +9296,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
8869
9296
|
switch (ggml_get_unary_op(dst)) {
|
|
8870
9297
|
case GGML_UNARY_OP_EXP:
|
|
8871
9298
|
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
|
|
9299
|
+
case GGML_UNARY_OP_ELU:
|
|
9300
|
+
return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
|
|
8872
9301
|
case GGML_UNARY_OP_SILU:
|
|
8873
9302
|
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
8874
9303
|
case GGML_UNARY_OP_GELU:
|
|
@@ -8905,6 +9334,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
8905
9334
|
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
|
|
8906
9335
|
case GGML_UNARY_OP_TRUNC:
|
|
8907
9336
|
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
|
|
9337
|
+
case GGML_UNARY_OP_SGN:
|
|
9338
|
+
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
|
|
8908
9339
|
default:
|
|
8909
9340
|
break;
|
|
8910
9341
|
}
|
|
@@ -9098,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
9098
9529
|
return ctx->device->pipeline_rwkv_wkv7_f32;
|
|
9099
9530
|
}
|
|
9100
9531
|
return nullptr;
|
|
9532
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
9533
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
9534
|
+
const uint32_t S_v = dst->src[2]->ne[0];
|
|
9535
|
+
const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
|
|
9536
|
+
uint32_t si;
|
|
9537
|
+
switch (S_v) {
|
|
9538
|
+
case 32: si = 0; break;
|
|
9539
|
+
case 64: si = 1; break;
|
|
9540
|
+
case 128: si = 2; break;
|
|
9541
|
+
default: return nullptr;
|
|
9542
|
+
}
|
|
9543
|
+
return ctx->device->pipeline_gated_delta_net[si][kda];
|
|
9544
|
+
}
|
|
9545
|
+
return nullptr;
|
|
9101
9546
|
case GGML_OP_SSM_SCAN:
|
|
9102
9547
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
9103
9548
|
const uint32_t d_state = src0->ne[0];
|
|
@@ -9654,16 +10099,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
9654
10099
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
9655
10100
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
9656
10101
|
|
|
9657
|
-
int nb1 = dst->op_params[0] /
|
|
9658
|
-
int nb2 = dst->op_params[1] /
|
|
9659
|
-
|
|
9660
|
-
int offset = dst->op_params[3] /
|
|
10102
|
+
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
|
|
10103
|
+
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
|
|
10104
|
+
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
|
|
10105
|
+
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
|
|
9661
10106
|
|
|
9662
|
-
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst,
|
|
10107
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
|
|
9663
10108
|
(uint32_t)ggml_nelements(src0),
|
|
9664
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)
|
|
10109
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
|
|
9665
10110
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
9666
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)
|
|
10111
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
|
|
9667
10112
|
0,
|
|
9668
10113
|
0.0f, 0.0f, offset,
|
|
9669
10114
|
});
|
|
@@ -9928,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
9928
10373
|
);
|
|
9929
10374
|
}
|
|
9930
10375
|
|
|
10376
|
+
static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
|
10377
|
+
const ggml_tensor * src_q = dst->src[0];
|
|
10378
|
+
const ggml_tensor * src_v = dst->src[2];
|
|
10379
|
+
const ggml_tensor * src_beta = dst->src[4];
|
|
10380
|
+
|
|
10381
|
+
GGML_ASSERT(dst->buffer != nullptr);
|
|
10382
|
+
|
|
10383
|
+
const uint32_t S_v = (uint32_t)src_v->ne[0];
|
|
10384
|
+
const uint32_t H = (uint32_t)src_v->ne[1];
|
|
10385
|
+
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
|
|
10386
|
+
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
|
|
10387
|
+
|
|
10388
|
+
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
|
|
10389
|
+
|
|
10390
|
+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
|
|
10391
|
+
GGML_ASSERT(pipeline != nullptr);
|
|
10392
|
+
|
|
10393
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
10394
|
+
|
|
10395
|
+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
|
10396
|
+
vk_subbuffer src_buf[6] = {};
|
|
10397
|
+
for (int i = 0; i < 6; i++) {
|
|
10398
|
+
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
|
|
10399
|
+
}
|
|
10400
|
+
|
|
10401
|
+
const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
|
|
10402
|
+
const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
|
|
10403
|
+
const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
|
|
10404
|
+
const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
|
|
10405
|
+
const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
|
|
10406
|
+
const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
|
|
10407
|
+
const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
|
|
10408
|
+
const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
|
|
10409
|
+
const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
|
|
10410
|
+
|
|
10411
|
+
const uint32_t neq1 = (uint32_t)src_q->ne[1];
|
|
10412
|
+
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
|
|
10413
|
+
|
|
10414
|
+
const float scale = 1.0f / sqrtf((float)S_v);
|
|
10415
|
+
const vk_op_gated_delta_net_push_constants pc = {
|
|
10416
|
+
H, n_tokens, n_seqs, s_off,
|
|
10417
|
+
sq1, sq2, sq3,
|
|
10418
|
+
sv1, sv2, sv3,
|
|
10419
|
+
sb1, sb2, sb3,
|
|
10420
|
+
neq1, rq3,
|
|
10421
|
+
scale
|
|
10422
|
+
};
|
|
10423
|
+
|
|
10424
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
10425
|
+
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
|
|
10426
|
+
pc, { H, n_seqs, 1u });
|
|
10427
|
+
}
|
|
10428
|
+
|
|
9931
10429
|
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
|
|
9932
10430
|
const ggml_tensor * src0 = dst->src[0];
|
|
9933
10431
|
const ggml_tensor * src1 = dst->src[1];
|
|
@@ -10335,12 +10833,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
|
|
|
10335
10833
|
|
|
10336
10834
|
uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
|
|
10337
10835
|
uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
|
|
10836
|
+
uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
|
|
10837
|
+
|
|
10838
|
+
uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
|
|
10839
|
+
uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
|
|
10840
|
+
uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
|
|
10338
10841
|
|
|
10339
10842
|
vk_op_rope_push_constants rope {
|
|
10340
|
-
(uint32_t)mode, (uint32_t)
|
|
10341
|
-
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
|
10342
|
-
has_ff, (uint32_t)src0->ne[2], nb01, nb02,
|
|
10843
|
+
(uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
|
|
10844
|
+
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
|
|
10343
10845
|
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
|
|
10846
|
+
|
|
10847
|
+
(uint32_t)src0->ne[0],
|
|
10848
|
+
(uint32_t)src0->ne[1],
|
|
10849
|
+
(uint32_t)src0->ne[2],
|
|
10850
|
+
nb01, nb02, nb03,
|
|
10851
|
+
nb11, nb12, nb13,
|
|
10344
10852
|
};
|
|
10345
10853
|
|
|
10346
10854
|
return rope;
|
|
@@ -10467,8 +10975,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
10467
10975
|
}
|
|
10468
10976
|
|
|
10469
10977
|
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
|
10470
|
-
float * op_params = (float *)dst->op_params;
|
|
10471
|
-
|
|
10978
|
+
const float * op_params = (const float *)dst->op_params;
|
|
10979
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
10980
|
+
p.param1 = op_params[0];
|
|
10981
|
+
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
|
|
10472
10982
|
}
|
|
10473
10983
|
|
|
10474
10984
|
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
|
@@ -11386,7 +11896,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
11386
11896
|
}
|
|
11387
11897
|
}
|
|
11388
11898
|
|
|
11389
|
-
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
|
11390
11899
|
if (split_k > 1) {
|
|
11391
11900
|
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
|
11392
11901
|
|
|
@@ -11560,7 +12069,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
11560
12069
|
free(d_chk);
|
|
11561
12070
|
|
|
11562
12071
|
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
|
11563
|
-
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
|
11564
12072
|
|
|
11565
12073
|
ggml_vk_destroy_buffer(d_X);
|
|
11566
12074
|
ggml_vk_destroy_buffer(d_Y);
|
|
@@ -11896,7 +12404,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
11896
12404
|
// y[i] = i % k;
|
|
11897
12405
|
}
|
|
11898
12406
|
|
|
11899
|
-
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
|
11900
12407
|
if (split_k > 1) {
|
|
11901
12408
|
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
|
11902
12409
|
|
|
@@ -11909,7 +12416,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
11909
12416
|
}
|
|
11910
12417
|
}
|
|
11911
12418
|
if (mmq) {
|
|
11912
|
-
|
|
12419
|
+
vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
|
12420
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
|
|
11913
12421
|
}
|
|
11914
12422
|
|
|
11915
12423
|
ggml_pipeline_allocate_descriptor_sets(ctx);
|
|
@@ -12145,7 +12653,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
|
|
12145
12653
|
ggml_vk_submit(subctx, {});
|
|
12146
12654
|
ctx->submit_pending = true;
|
|
12147
12655
|
ggml_vk_synchronize(ctx);
|
|
12656
|
+
GGML_ASSERT(ctx->compute_ctx.expired());
|
|
12148
12657
|
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
12658
|
+
ctx->compute_ctx = subctx;
|
|
12149
12659
|
}
|
|
12150
12660
|
|
|
12151
12661
|
if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
|
|
@@ -12163,6 +12673,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
|
|
12163
12673
|
ggml_vk_destroy_buffer(ctx->prealloc_y);
|
|
12164
12674
|
}
|
|
12165
12675
|
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
|
|
12676
|
+
ctx->prealloc_y_last_tensor_used = nullptr;
|
|
12166
12677
|
}
|
|
12167
12678
|
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
|
|
12168
12679
|
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
|
|
@@ -12191,6 +12702,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12191
12702
|
if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
|
|
12192
12703
|
return false;
|
|
12193
12704
|
}
|
|
12705
|
+
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
|
12706
|
+
return false;
|
|
12707
|
+
}
|
|
12194
12708
|
|
|
12195
12709
|
VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
|
|
12196
12710
|
ctx->semaphore_idx = 0;
|
|
@@ -12215,15 +12729,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12215
12729
|
}
|
|
12216
12730
|
}
|
|
12217
12731
|
|
|
12218
|
-
vk_context compute_ctx;
|
|
12219
|
-
|
|
12220
|
-
if (ctx->compute_ctx.expired()) {
|
|
12221
|
-
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
12222
|
-
ctx->compute_ctx = compute_ctx;
|
|
12223
|
-
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
12224
|
-
} else {
|
|
12225
|
-
compute_ctx = ctx->compute_ctx.lock();
|
|
12226
|
-
}
|
|
12732
|
+
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
12227
12733
|
|
|
12228
12734
|
{
|
|
12229
12735
|
// This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers
|
|
@@ -12294,7 +12800,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12294
12800
|
|
|
12295
12801
|
if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
|
|
12296
12802
|
ctx->query_node_idx[ctx->query_idx] = node_idx;
|
|
12297
|
-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
12803
|
+
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
12298
12804
|
}
|
|
12299
12805
|
}
|
|
12300
12806
|
// Add all fused nodes to the unsynchronized lists.
|
|
@@ -12337,6 +12843,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12337
12843
|
|
|
12338
12844
|
break;
|
|
12339
12845
|
case GGML_OP_ACC:
|
|
12846
|
+
case GGML_OP_SET:
|
|
12340
12847
|
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
|
|
12341
12848
|
|
|
12342
12849
|
break;
|
|
@@ -12471,6 +12978,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12471
12978
|
}
|
|
12472
12979
|
|
|
12473
12980
|
switch (ggml_get_unary_op(node)) {
|
|
12981
|
+
case GGML_UNARY_OP_ELU:
|
|
12474
12982
|
case GGML_UNARY_OP_EXP:
|
|
12475
12983
|
case GGML_UNARY_OP_SILU:
|
|
12476
12984
|
case GGML_UNARY_OP_GELU:
|
|
@@ -12489,6 +12997,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12489
12997
|
case GGML_UNARY_OP_CEIL:
|
|
12490
12998
|
case GGML_UNARY_OP_FLOOR:
|
|
12491
12999
|
case GGML_UNARY_OP_TRUNC:
|
|
13000
|
+
case GGML_UNARY_OP_SGN:
|
|
12492
13001
|
ggml_vk_unary(ctx, compute_ctx, src0, node);
|
|
12493
13002
|
break;
|
|
12494
13003
|
case GGML_UNARY_OP_XIELU:
|
|
@@ -12633,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
12633
13142
|
|
|
12634
13143
|
break;
|
|
12635
13144
|
|
|
13145
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
13146
|
+
ggml_vk_gated_delta_net(ctx, compute_ctx, node);
|
|
13147
|
+
|
|
13148
|
+
break;
|
|
13149
|
+
|
|
12636
13150
|
case GGML_OP_SSM_SCAN:
|
|
12637
13151
|
ggml_vk_ssm_scan(ctx, compute_ctx, node);
|
|
12638
13152
|
|
|
@@ -12740,7 +13254,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
12740
13254
|
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
|
|
12741
13255
|
|
|
12742
13256
|
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
|
12743
|
-
|
|
13257
|
+
if (ctx->device->async_use_transfer_queue) {
|
|
13258
|
+
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
|
13259
|
+
}
|
|
12744
13260
|
|
|
12745
13261
|
for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
|
|
12746
13262
|
ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
|
|
@@ -12769,7 +13285,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
12769
13285
|
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|
12770
13286
|
VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
|
|
12771
13287
|
// discard any unsubmitted command buffers
|
|
12772
|
-
ctx->
|
|
13288
|
+
ctx->compute_ctx.reset();
|
|
12773
13289
|
// wait for any pending command buffers to finish
|
|
12774
13290
|
ggml_vk_synchronize(ctx);
|
|
12775
13291
|
|
|
@@ -12802,7 +13318,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
12802
13318
|
ctx->descriptor_sets.clear();
|
|
12803
13319
|
|
|
12804
13320
|
ctx->compute_cmd_pool.destroy(ctx->device->device);
|
|
12805
|
-
|
|
13321
|
+
if (ctx->device->async_use_transfer_queue) {
|
|
13322
|
+
ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);
|
|
13323
|
+
|
|
13324
|
+
ctx->transfer_cmd_pool.destroy(ctx->device->device);
|
|
13325
|
+
}
|
|
12806
13326
|
if (vk_perf_logger_enabled) {
|
|
12807
13327
|
ctx->perf_logger->print_timings(true);
|
|
12808
13328
|
}
|
|
@@ -12861,6 +13381,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
|
|
|
12861
13381
|
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
12862
13382
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
12863
13383
|
|
|
13384
|
+
if (size == 0) {
|
|
13385
|
+
return;
|
|
13386
|
+
}
|
|
13387
|
+
|
|
12864
13388
|
uint32_t val32 = (uint32_t)value * 0x01010101;
|
|
12865
13389
|
ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
|
|
12866
13390
|
}
|
|
@@ -12870,6 +13394,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
|
|
|
12870
13394
|
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
12871
13395
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
12872
13396
|
|
|
13397
|
+
if (size == 0) {
|
|
13398
|
+
return;
|
|
13399
|
+
}
|
|
13400
|
+
|
|
12873
13401
|
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
|
12874
13402
|
}
|
|
12875
13403
|
|
|
@@ -12877,12 +13405,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
|
|
|
12877
13405
|
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
|
12878
13406
|
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
|
12879
13407
|
|
|
13408
|
+
if (size == 0) {
|
|
13409
|
+
return;
|
|
13410
|
+
}
|
|
13411
|
+
|
|
12880
13412
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
12881
13413
|
|
|
12882
13414
|
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
|
12883
13415
|
}
|
|
12884
13416
|
|
|
12885
13417
|
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
|
13418
|
+
if (ggml_nbytes(src) == 0) {
|
|
13419
|
+
return true;
|
|
13420
|
+
}
|
|
13421
|
+
|
|
12886
13422
|
if (ggml_backend_buffer_is_vk(src->buffer)) {
|
|
12887
13423
|
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
|
12888
13424
|
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
@@ -13072,36 +13608,44 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
|
|
13072
13608
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
13073
13609
|
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
|
13074
13610
|
|
|
13611
|
+
if (size == 0) {
|
|
13612
|
+
return;
|
|
13613
|
+
}
|
|
13614
|
+
|
|
13075
13615
|
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
|
13076
13616
|
|
|
13077
|
-
vk_context
|
|
13617
|
+
vk_context cpy_ctx;
|
|
13078
13618
|
|
|
13079
|
-
if (ctx->
|
|
13080
|
-
|
|
13081
|
-
|
|
13082
|
-
|
|
13083
|
-
|
|
13619
|
+
if (ctx->device->async_use_transfer_queue) {
|
|
13620
|
+
if (ctx->transfer_ctx.expired()) {
|
|
13621
|
+
// Initialize new transfer context
|
|
13622
|
+
cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
|
|
13623
|
+
ctx->transfer_ctx = cpy_ctx;
|
|
13624
|
+
ggml_vk_ctx_begin(ctx->device, cpy_ctx);
|
|
13625
|
+
} else {
|
|
13626
|
+
cpy_ctx = ctx->transfer_ctx.lock();
|
|
13627
|
+
}
|
|
13084
13628
|
} else {
|
|
13085
|
-
|
|
13629
|
+
cpy_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13086
13630
|
}
|
|
13087
13631
|
|
|
13088
13632
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
13089
13633
|
|
|
13090
13634
|
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
13091
13635
|
|
|
13092
|
-
bool ret = ggml_vk_buffer_write_async(
|
|
13636
|
+
bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
|
|
13093
13637
|
|
|
13094
13638
|
if (!ret) {
|
|
13095
13639
|
ggml_vk_ensure_sync_staging_buffer(ctx, size);
|
|
13096
|
-
ggml_vk_sync_buffers(nullptr,
|
|
13640
|
+
ggml_vk_sync_buffers(nullptr, cpy_ctx);
|
|
13097
13641
|
|
|
13098
13642
|
vk::BufferCopy buffer_cpy;
|
|
13099
13643
|
buffer_cpy.srcOffset = 0;
|
|
13100
13644
|
buffer_cpy.dstOffset = dst_offset;
|
|
13101
13645
|
buffer_cpy.size = size;
|
|
13102
13646
|
|
|
13103
|
-
|
|
13104
|
-
deferred_memcpy(ctx->sync_staging->ptr, data, size, &
|
|
13647
|
+
cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
|
|
13648
|
+
deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
|
|
13105
13649
|
ggml_vk_synchronize(ctx);
|
|
13106
13650
|
}
|
|
13107
13651
|
}
|
|
@@ -13111,101 +13655,156 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
|
|
13111
13655
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
13112
13656
|
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
|
13113
13657
|
|
|
13114
|
-
|
|
13658
|
+
if (size == 0) {
|
|
13659
|
+
return;
|
|
13660
|
+
}
|
|
13115
13661
|
|
|
13116
|
-
|
|
13662
|
+
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
|
13117
13663
|
|
|
13118
|
-
|
|
13119
|
-
// Initialize new transfer context
|
|
13120
|
-
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
13121
|
-
ctx->transfer_ctx = transfer_ctx;
|
|
13122
|
-
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
|
13123
|
-
} else {
|
|
13124
|
-
transfer_ctx = ctx->transfer_ctx.lock();
|
|
13125
|
-
}
|
|
13664
|
+
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13126
13665
|
|
|
13127
13666
|
vk_buffer buf = buf_ctx->dev_buffer;
|
|
13128
13667
|
|
|
13129
13668
|
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
|
13130
|
-
bool ret = ggml_vk_buffer_read_async(
|
|
13669
|
+
bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
|
|
13131
13670
|
|
|
13132
13671
|
// If that failed, copy synchronously through a staging buffer
|
|
13133
13672
|
if (!ret) {
|
|
13134
13673
|
ggml_vk_ensure_sync_staging_buffer(ctx, size);
|
|
13135
|
-
ggml_vk_sync_buffers(nullptr,
|
|
13674
|
+
ggml_vk_sync_buffers(nullptr, compute_ctx);
|
|
13136
13675
|
|
|
13137
13676
|
vk::BufferCopy buffer_cpy;
|
|
13138
13677
|
buffer_cpy.srcOffset = src_offset;
|
|
13139
13678
|
buffer_cpy.dstOffset = 0;
|
|
13140
13679
|
buffer_cpy.size = size;
|
|
13141
13680
|
|
|
13142
|
-
|
|
13143
|
-
deferred_memcpy(data, ctx->sync_staging->ptr, size, &
|
|
13681
|
+
compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
|
|
13682
|
+
deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
|
|
13144
13683
|
ggml_vk_synchronize(ctx);
|
|
13145
13684
|
}
|
|
13146
13685
|
}
|
|
13147
13686
|
|
|
13148
|
-
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t
|
|
13149
|
-
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
|
|
13150
|
-
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)
|
|
13151
|
-
if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
|
|
13152
|
-
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
|
13153
|
-
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
13687
|
+
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
|
|
13688
|
+
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
|
|
13689
|
+
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
|
|
13154
13690
|
|
|
13155
|
-
|
|
13691
|
+
// Skip zero-size tensors
|
|
13692
|
+
if (ggml_nbytes(src) == 0) {
|
|
13693
|
+
return true;
|
|
13694
|
+
}
|
|
13156
13695
|
|
|
13157
|
-
|
|
13158
|
-
|
|
13159
|
-
|
|
13160
|
-
|
|
13161
|
-
|
|
13162
|
-
|
|
13163
|
-
|
|
13696
|
+
if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
|
|
13697
|
+
return false;
|
|
13698
|
+
}
|
|
13699
|
+
|
|
13700
|
+
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
13701
|
+
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
|
|
13702
|
+
|
|
13703
|
+
if (ggml_backend_buffer_is_vk(src->buffer)) {
|
|
13704
|
+
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
|
13705
|
+
|
|
13706
|
+
// Async copy only works within the same device
|
|
13707
|
+
if (src_buf_ctx->dev_buffer->device != dst_buf->device) {
|
|
13708
|
+
return false;
|
|
13164
13709
|
}
|
|
13165
13710
|
|
|
13166
|
-
|
|
13167
|
-
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
|
|
13711
|
+
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13168
13712
|
|
|
13169
|
-
ggml_vk_buffer_copy_async(
|
|
13713
|
+
ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,
|
|
13714
|
+
src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,
|
|
13715
|
+
ggml_nbytes(src));
|
|
13170
13716
|
return true;
|
|
13171
13717
|
}
|
|
13172
13718
|
|
|
13719
|
+
if (ggml_backend_buffer_is_host(src->buffer)) {
|
|
13720
|
+
vk_buffer pinned_buf = nullptr;
|
|
13721
|
+
size_t pinned_offset = 0;
|
|
13722
|
+
ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);
|
|
13723
|
+
if (pinned_buf == nullptr) {
|
|
13724
|
+
return false;
|
|
13725
|
+
}
|
|
13726
|
+
|
|
13727
|
+
vk_context cpy_ctx;
|
|
13728
|
+
if (ctx->device->async_use_transfer_queue) {
|
|
13729
|
+
if (ctx->transfer_ctx.expired()) {
|
|
13730
|
+
cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
|
|
13731
|
+
ctx->transfer_ctx = cpy_ctx;
|
|
13732
|
+
ggml_vk_ctx_begin(ctx->device, cpy_ctx);
|
|
13733
|
+
} else {
|
|
13734
|
+
cpy_ctx = ctx->transfer_ctx.lock();
|
|
13735
|
+
}
|
|
13736
|
+
} else {
|
|
13737
|
+
cpy_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13738
|
+
}
|
|
13739
|
+
|
|
13740
|
+
return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,
|
|
13741
|
+
vk_tensor_offset(dst) + dst->view_offs,
|
|
13742
|
+
src->data, ggml_nbytes(src));
|
|
13743
|
+
}
|
|
13744
|
+
|
|
13745
|
+
GGML_UNUSED(backend_src);
|
|
13173
13746
|
return false;
|
|
13174
13747
|
}
|
|
13175
13748
|
|
|
13176
13749
|
static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
|
13177
13750
|
VK_LOG_DEBUG("ggml_vk_synchronize()");
|
|
13178
13751
|
|
|
13179
|
-
bool do_transfer = !ctx->
|
|
13752
|
+
bool do_transfer = !ctx->compute_ctx.expired();
|
|
13180
13753
|
|
|
13181
|
-
|
|
13754
|
+
if (ggml_vk_submit_transfer_ctx(ctx)) {
|
|
13755
|
+
ctx->submit_pending = true;
|
|
13756
|
+
}
|
|
13757
|
+
|
|
13758
|
+
vk_context compute_ctx;
|
|
13759
|
+
vk_command_buffer* cmd_buf = nullptr;
|
|
13182
13760
|
if (do_transfer) {
|
|
13183
|
-
|
|
13761
|
+
compute_ctx = ctx->compute_ctx.lock();
|
|
13762
|
+
if (compute_ctx->s) {
|
|
13763
|
+
cmd_buf = compute_ctx->s->buffer;
|
|
13764
|
+
}
|
|
13184
13765
|
|
|
13185
|
-
ggml_vk_ctx_end(
|
|
13766
|
+
ggml_vk_ctx_end(compute_ctx);
|
|
13186
13767
|
|
|
13187
|
-
for (auto& cpy :
|
|
13768
|
+
for (auto& cpy : compute_ctx->in_memcpys) {
|
|
13188
13769
|
memcpy(cpy.dst, cpy.src, cpy.n);
|
|
13189
13770
|
}
|
|
13190
13771
|
|
|
13191
|
-
ggml_vk_submit(
|
|
13772
|
+
ggml_vk_submit(compute_ctx, {});
|
|
13192
13773
|
ctx->submit_pending = true;
|
|
13193
13774
|
}
|
|
13194
13775
|
|
|
13195
13776
|
if (ctx->submit_pending) {
|
|
13196
|
-
{
|
|
13777
|
+
if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
|
|
13778
|
+
vk::TimelineSemaphoreSubmitInfo tl_info{
|
|
13779
|
+
1, &ctx->transfer_semaphore.value,
|
|
13780
|
+
0, nullptr,
|
|
13781
|
+
};
|
|
13782
|
+
vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;
|
|
13783
|
+
vk::SubmitInfo si{
|
|
13784
|
+
1, &ctx->transfer_semaphore.s, &stage,
|
|
13785
|
+
0, nullptr,
|
|
13786
|
+
0, nullptr,
|
|
13787
|
+
};
|
|
13788
|
+
si.setPNext(&tl_info);
|
|
13789
|
+
std::lock_guard<std::mutex> guard(queue_mutex);
|
|
13790
|
+
ctx->device->compute_queue.queue.submit({ si }, ctx->fence);
|
|
13791
|
+
ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
|
|
13792
|
+
} else {
|
|
13197
13793
|
std::lock_guard<std::mutex> guard(queue_mutex);
|
|
13198
13794
|
ctx->device->compute_queue.queue.submit({}, ctx->fence);
|
|
13199
13795
|
}
|
|
13200
13796
|
ggml_vk_wait_for_fence(ctx);
|
|
13201
13797
|
ctx->submit_pending = false;
|
|
13798
|
+
if (cmd_buf) {
|
|
13799
|
+
cmd_buf->in_use = false;
|
|
13800
|
+
}
|
|
13202
13801
|
}
|
|
13203
13802
|
|
|
13204
13803
|
if (do_transfer) {
|
|
13205
|
-
for (auto& cpy :
|
|
13804
|
+
for (auto& cpy : compute_ctx->out_memcpys) {
|
|
13206
13805
|
memcpy(cpy.dst, cpy.src, cpy.n);
|
|
13207
13806
|
}
|
|
13208
|
-
ctx->
|
|
13807
|
+
ctx->compute_ctx.reset();
|
|
13209
13808
|
}
|
|
13210
13809
|
}
|
|
13211
13810
|
|
|
@@ -13505,12 +14104,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
|
|
|
13505
14104
|
return true;
|
|
13506
14105
|
}
|
|
13507
14106
|
|
|
13508
|
-
// Check whether the tensors overlap in memory
|
|
13509
|
-
// Fusions can
|
|
13510
|
-
// by ggml-alloc. If the fusion is
|
|
13511
|
-
// to overlap if they are exactly equal.
|
|
13512
|
-
|
|
13513
|
-
static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
|
|
14107
|
+
// Check whether the tensors overlap in memory.
|
|
14108
|
+
// Fusions can potentially overwrite src tensors in ways that are not prevented
|
|
14109
|
+
// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
|
|
14110
|
+
// with the destination, then it's OK for them to overlap if they are exactly equal.
|
|
14111
|
+
static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
|
|
13514
14112
|
ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
|
|
13515
14113
|
vk_buffer a_buf = a_buf_ctx->dev_buffer;
|
|
13516
14114
|
ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
|
|
@@ -13521,7 +14119,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
|
|
|
13521
14119
|
auto b_base = vk_tensor_offset(b) + b->view_offs;
|
|
13522
14120
|
auto b_size = ggml_nbytes(b);
|
|
13523
14121
|
|
|
13524
|
-
if (a_base == b_base && a_size == b_size) {
|
|
14122
|
+
if (elementwise && a_base == b_base && a_size == b_size) {
|
|
13525
14123
|
return false;
|
|
13526
14124
|
}
|
|
13527
14125
|
|
|
@@ -13559,13 +14157,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
|
|
|
13559
14157
|
return false;
|
|
13560
14158
|
}
|
|
13561
14159
|
|
|
13562
|
-
// must not overwrite srcs in a way that's not elementwise
|
|
13563
|
-
ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
|
|
13564
|
-
if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
|
|
13565
|
-
ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
|
|
13566
|
-
return false;
|
|
13567
|
-
}
|
|
13568
|
-
|
|
13569
14160
|
// conditions for pipeline creation
|
|
13570
14161
|
if (!(ctx->device->float_controls_rte_fp16 &&
|
|
13571
14162
|
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
|
|
@@ -13627,6 +14218,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
|
|
|
13627
14218
|
return num_adds;
|
|
13628
14219
|
}
|
|
13629
14220
|
|
|
14221
|
+
static int32_t find_first_set(uint32_t x) {
|
|
14222
|
+
int32_t ret = 0;
|
|
14223
|
+
if (!x) {
|
|
14224
|
+
return -1;
|
|
14225
|
+
}
|
|
14226
|
+
while (!(x & 1)) {
|
|
14227
|
+
x >>= 1;
|
|
14228
|
+
ret++;
|
|
14229
|
+
}
|
|
14230
|
+
return ret;
|
|
14231
|
+
}
|
|
14232
|
+
|
|
13630
14233
|
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
13631
14234
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
|
13632
14235
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
@@ -13645,7 +14248,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13645
14248
|
int last_node = cgraph->n_nodes - 1;
|
|
13646
14249
|
|
|
13647
14250
|
// If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
|
|
13648
|
-
while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
|
|
14251
|
+
while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
|
|
13649
14252
|
last_node -= 1;
|
|
13650
14253
|
}
|
|
13651
14254
|
|
|
@@ -13655,6 +14258,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13655
14258
|
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
|
13656
14259
|
int submit_node_idx = 0; // index to first node in a batch
|
|
13657
14260
|
|
|
14261
|
+
ggml_vk_submit_transfer_ctx(ctx);
|
|
14262
|
+
|
|
13658
14263
|
vk_context compute_ctx;
|
|
13659
14264
|
if (vk_perf_logger_enabled) {
|
|
13660
14265
|
// allocate/resize the query pool
|
|
@@ -13680,11 +14285,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13680
14285
|
std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
|
|
13681
14286
|
|
|
13682
14287
|
GGML_ASSERT(ctx->compute_ctx.expired());
|
|
13683
|
-
compute_ctx =
|
|
13684
|
-
ctx->compute_ctx = compute_ctx;
|
|
13685
|
-
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
14288
|
+
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13686
14289
|
ctx->query_idx = 0;
|
|
13687
|
-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
14290
|
+
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
13688
14291
|
}
|
|
13689
14292
|
|
|
13690
14293
|
ctx->prealloc_y_last_pipeline_used = nullptr;
|
|
@@ -13692,13 +14295,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13692
14295
|
|
|
13693
14296
|
if (ctx->prealloc_size_add_rms_partials) {
|
|
13694
14297
|
ggml_vk_preallocate_buffers(ctx, nullptr);
|
|
13695
|
-
|
|
13696
|
-
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
13697
|
-
ctx->compute_ctx = compute_ctx;
|
|
13698
|
-
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
13699
|
-
} else {
|
|
13700
|
-
compute_ctx = ctx->compute_ctx.lock();
|
|
13701
|
-
}
|
|
14298
|
+
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13702
14299
|
// initialize partial sums to zero.
|
|
13703
14300
|
ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
|
|
13704
14301
|
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
@@ -13725,6 +14322,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13725
14322
|
total_mul_mat_bytes += bytes;
|
|
13726
14323
|
}
|
|
13727
14324
|
|
|
14325
|
+
// op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
|
|
14326
|
+
// the fused result in an elementwise-way. This affects whether the memory for
|
|
14327
|
+
// the src is allowed to overlap the memory for the destination.
|
|
14328
|
+
// The array is sized to handle the largest fusion (asserted later).
|
|
14329
|
+
bool op_srcs_fused_elementwise[12];
|
|
14330
|
+
|
|
13728
14331
|
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
|
|
13729
14332
|
ctx->fused_topk_moe_scale = false;
|
|
13730
14333
|
const char *fusion_string {};
|
|
@@ -13733,39 +14336,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13733
14336
|
if (num_adds) {
|
|
13734
14337
|
ctx->num_additional_fused_ops = num_adds - 1;
|
|
13735
14338
|
fusion_string = "MULTI_ADD";
|
|
14339
|
+
std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
|
|
13736
14340
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
|
|
13737
14341
|
ctx->num_additional_fused_ops = 2;
|
|
13738
14342
|
fusion_string = "MUL_MAT_ADD_ADD";
|
|
14343
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14344
|
+
op_srcs_fused_elementwise[1] = true;
|
|
14345
|
+
op_srcs_fused_elementwise[2] = true;
|
|
13739
14346
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
|
13740
14347
|
ctx->num_additional_fused_ops = 1;
|
|
13741
14348
|
fusion_string = "MUL_MAT_ADD";
|
|
14349
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14350
|
+
op_srcs_fused_elementwise[1] = true;
|
|
13742
14351
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
|
|
13743
14352
|
ctx->num_additional_fused_ops = 2;
|
|
13744
14353
|
fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
|
|
14354
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14355
|
+
op_srcs_fused_elementwise[1] = true;
|
|
14356
|
+
op_srcs_fused_elementwise[2] = true;
|
|
13745
14357
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
|
|
13746
14358
|
ctx->num_additional_fused_ops = 1;
|
|
13747
14359
|
fusion_string = "MUL_MAT_ID_ADD_ID";
|
|
14360
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14361
|
+
op_srcs_fused_elementwise[1] = true;
|
|
13748
14362
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
|
|
13749
14363
|
ctx->num_additional_fused_ops = 1;
|
|
13750
14364
|
fusion_string = "MUL_MAT_ID_MUL";
|
|
14365
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14366
|
+
op_srcs_fused_elementwise[1] = true;
|
|
13751
14367
|
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
|
|
13752
14368
|
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
|
|
13753
14369
|
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
|
|
13754
14370
|
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
|
|
13755
14371
|
ctx->num_additional_fused_ops = 4;
|
|
13756
14372
|
fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
|
|
14373
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14374
|
+
op_srcs_fused_elementwise[1] = false;
|
|
14375
|
+
op_srcs_fused_elementwise[2] = false;
|
|
14376
|
+
op_srcs_fused_elementwise[3] = false;
|
|
14377
|
+
op_srcs_fused_elementwise[4] = false;
|
|
13757
14378
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
|
|
13758
14379
|
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
|
|
13759
14380
|
ctx->num_additional_fused_ops = 2;
|
|
13760
14381
|
fusion_string = "RMS_NORM_MUL_ROPE";
|
|
14382
|
+
// rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
|
|
14383
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14384
|
+
op_srcs_fused_elementwise[1] = true;
|
|
14385
|
+
op_srcs_fused_elementwise[2] = true;
|
|
13761
14386
|
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
13762
14387
|
ctx->num_additional_fused_ops = 1;
|
|
13763
14388
|
fusion_string = "RMS_NORM_MUL";
|
|
14389
|
+
// rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
|
|
14390
|
+
// they are overwritten, and one workgroup per row. So close enough.
|
|
14391
|
+
op_srcs_fused_elementwise[0] = true;
|
|
14392
|
+
op_srcs_fused_elementwise[1] = true;
|
|
13764
14393
|
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
|
|
13765
14394
|
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
|
|
13766
14395
|
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
|
|
13767
14396
|
ctx->num_additional_fused_ops = 2;
|
|
13768
14397
|
fusion_string = "ROPE_VIEW_SET_ROWS";
|
|
14398
|
+
op_srcs_fused_elementwise[0] = false;
|
|
14399
|
+
op_srcs_fused_elementwise[1] = false;
|
|
14400
|
+
op_srcs_fused_elementwise[2] = false;
|
|
13769
14401
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
|
|
13770
14402
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
|
|
13771
14403
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
|
|
@@ -13774,6 +14406,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13774
14406
|
ctx->fused_ops_write_mask |= 1 << 3;
|
|
13775
14407
|
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
|
|
13776
14408
|
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
|
|
14409
|
+
std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
|
|
13777
14410
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
|
|
13778
14411
|
ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
|
|
13779
14412
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
|
|
@@ -13782,6 +14415,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13782
14415
|
ctx->fused_ops_write_mask |= 1 << 4;
|
|
13783
14416
|
ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
|
|
13784
14417
|
fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
|
|
14418
|
+
std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
|
|
13785
14419
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
|
|
13786
14420
|
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
|
|
13787
14421
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
|
|
@@ -13790,6 +14424,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13790
14424
|
ctx->fused_ops_write_mask |= 1 << 3;
|
|
13791
14425
|
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
|
|
13792
14426
|
fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
|
|
14427
|
+
std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
|
|
13793
14428
|
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
|
|
13794
14429
|
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
|
|
13795
14430
|
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
|
|
@@ -13798,6 +14433,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13798
14433
|
ctx->fused_ops_write_mask |= 1 << 1;
|
|
13799
14434
|
ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
|
|
13800
14435
|
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
|
|
14436
|
+
std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
|
|
13801
14437
|
}
|
|
13802
14438
|
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
|
|
13803
14439
|
// Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
|
|
@@ -13805,11 +14441,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13805
14441
|
ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
|
|
13806
14442
|
ctx->fused_topk_moe_scale = true;
|
|
13807
14443
|
ctx->num_additional_fused_ops++;
|
|
14444
|
+
op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
|
|
13808
14445
|
}
|
|
13809
14446
|
}
|
|
13810
14447
|
}
|
|
14448
|
+
GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
|
|
13811
14449
|
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
|
|
13812
14450
|
|
|
14451
|
+
// Check whether fusion would overwrite src operands while they're still in use.
|
|
14452
|
+
// If so, disable fusion.
|
|
14453
|
+
if (ctx->num_additional_fused_ops) {
|
|
14454
|
+
// There are up to two output nodes - topk_moe has two.
|
|
14455
|
+
uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
|
|
14456
|
+
ggml_tensor *output_nodes[2] {};
|
|
14457
|
+
output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
|
|
14458
|
+
if (bits) {
|
|
14459
|
+
int output_idx = find_first_set(bits);
|
|
14460
|
+
GGML_ASSERT(bits == (1u << output_idx));
|
|
14461
|
+
output_nodes[1] = cgraph->nodes[i + output_idx];
|
|
14462
|
+
}
|
|
14463
|
+
|
|
14464
|
+
bool need_disable = false;
|
|
14465
|
+
|
|
14466
|
+
// topk_moe often overwrites the source, but for a given row all the src values are
|
|
14467
|
+
// loaded before anything is stored. If there's only one row, this is safe, so treat
|
|
14468
|
+
// this as a special case.
|
|
14469
|
+
bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
|
|
14470
|
+
ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
|
|
14471
|
+
|
|
14472
|
+
if (!is_topk_moe_single_row) {
|
|
14473
|
+
for (int j = 0; j < 2; ++j) {
|
|
14474
|
+
ggml_tensor *dst = output_nodes[j];
|
|
14475
|
+
if (!dst) {
|
|
14476
|
+
continue;
|
|
14477
|
+
}
|
|
14478
|
+
// Loop over all srcs of all nodes in the fusion. If the src overlaps
|
|
14479
|
+
// the destination and the src is not an intermediate node that's being
|
|
14480
|
+
// elided, then disable fusion.
|
|
14481
|
+
for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
|
|
14482
|
+
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
|
|
14483
|
+
ggml_tensor *src = cgraph->nodes[i + k]->src[s];
|
|
14484
|
+
if (!src || src->op == GGML_OP_NONE) {
|
|
14485
|
+
continue;
|
|
14486
|
+
}
|
|
14487
|
+
if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
|
|
14488
|
+
bool found = false;
|
|
14489
|
+
for (int n = 0; n < k; ++n) {
|
|
14490
|
+
if (cgraph->nodes[i + n] == src) {
|
|
14491
|
+
found = true;
|
|
14492
|
+
break;
|
|
14493
|
+
}
|
|
14494
|
+
}
|
|
14495
|
+
if (!found) {
|
|
14496
|
+
need_disable = true;
|
|
14497
|
+
}
|
|
14498
|
+
}
|
|
14499
|
+
}
|
|
14500
|
+
}
|
|
14501
|
+
}
|
|
14502
|
+
}
|
|
14503
|
+
if (need_disable) {
|
|
14504
|
+
ctx->num_additional_fused_ops = 0;
|
|
14505
|
+
ctx->fused_ops_write_mask = 1;
|
|
14506
|
+
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
|
|
14507
|
+
ctx->fused_topk_moe_scale = false;
|
|
14508
|
+
}
|
|
14509
|
+
}
|
|
14510
|
+
|
|
13813
14511
|
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
|
13814
14512
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
13815
14513
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
@@ -13820,18 +14518,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13820
14518
|
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
|
|
13821
14519
|
|
|
13822
14520
|
if (vk_perf_logger_enabled && enqueued) {
|
|
13823
|
-
|
|
13824
|
-
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
13825
|
-
ctx->compute_ctx = compute_ctx;
|
|
13826
|
-
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
13827
|
-
} else {
|
|
13828
|
-
compute_ctx = ctx->compute_ctx.lock();
|
|
13829
|
-
}
|
|
14521
|
+
compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
13830
14522
|
if (!vk_perf_logger_concurrent) {
|
|
13831
14523
|
// track a single node/fusion for the current query
|
|
13832
14524
|
ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
|
|
13833
14525
|
ctx->query_fusion_names[ctx->query_idx] = fusion_string;
|
|
13834
|
-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
14526
|
+
compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
|
|
13835
14527
|
} else {
|
|
13836
14528
|
// track a fusion string and number of fused ops for the current node_idx
|
|
13837
14529
|
ctx->query_fusion_names[i] = fusion_string;
|
|
@@ -13874,6 +14566,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
13874
14566
|
ggml_vk_submit(compute_ctx, ctx->device->fence);
|
|
13875
14567
|
VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
|
|
13876
14568
|
ctx->device->device.resetFences({ ctx->device->fence });
|
|
14569
|
+
ctx->compute_ctx.reset();
|
|
13877
14570
|
|
|
13878
14571
|
// Get the results and pass them to the logger
|
|
13879
14572
|
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
|
|
@@ -14160,29 +14853,24 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
|
|
14160
14853
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
14161
14854
|
vk_event *vkev = (vk_event *)event->context;
|
|
14162
14855
|
|
|
14163
|
-
|
|
14856
|
+
ggml_vk_submit_transfer_ctx(ctx);
|
|
14164
14857
|
|
|
14165
|
-
|
|
14166
|
-
|
|
14167
|
-
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
14168
|
-
ctx->transfer_ctx = transfer_ctx;
|
|
14169
|
-
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
|
14170
|
-
} else {
|
|
14171
|
-
transfer_ctx = ctx->transfer_ctx.lock();
|
|
14172
|
-
}
|
|
14858
|
+
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
14859
|
+
auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
|
|
14173
14860
|
|
|
14174
14861
|
// the backend interface doesn't have an explicit reset, so reset it here
|
|
14175
14862
|
// before we record the command to set it
|
|
14176
14863
|
ctx->device->device.resetEvent(vkev->event);
|
|
14177
14864
|
ctx->device->device.resetFences({ vkev->fence });
|
|
14178
14865
|
|
|
14179
|
-
ggml_vk_set_event(
|
|
14866
|
+
ggml_vk_set_event(compute_ctx, vkev->event);
|
|
14180
14867
|
|
|
14181
|
-
ggml_vk_ctx_end(
|
|
14868
|
+
ggml_vk_ctx_end(compute_ctx);
|
|
14182
14869
|
|
|
14183
|
-
ggml_vk_submit(
|
|
14870
|
+
ggml_vk_submit(compute_ctx, {vkev->fence});
|
|
14184
14871
|
ctx->submit_pending = true;
|
|
14185
|
-
|
|
14872
|
+
vkev->cmd_buffer = cmd_buf;
|
|
14873
|
+
ctx->compute_ctx.reset();
|
|
14186
14874
|
}
|
|
14187
14875
|
|
|
14188
14876
|
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
|
|
@@ -14190,20 +14878,11 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
|
|
14190
14878
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
14191
14879
|
vk_event *vkev = (vk_event *)event->context;
|
|
14192
14880
|
|
|
14193
|
-
vk_context
|
|
14194
|
-
|
|
14195
|
-
if (ctx->transfer_ctx.expired()) {
|
|
14196
|
-
// Initialize new transfer context
|
|
14197
|
-
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
|
14198
|
-
ctx->transfer_ctx = transfer_ctx;
|
|
14199
|
-
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
|
14200
|
-
} else {
|
|
14201
|
-
transfer_ctx = ctx->transfer_ctx.lock();
|
|
14202
|
-
}
|
|
14881
|
+
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
|
|
14203
14882
|
|
|
14204
|
-
ggml_vk_wait_events(
|
|
14205
|
-
ggml_vk_ctx_end(
|
|
14206
|
-
ctx->
|
|
14883
|
+
ggml_vk_wait_events(compute_ctx, {vkev->event});
|
|
14884
|
+
ggml_vk_ctx_end(compute_ctx);
|
|
14885
|
+
ctx->compute_ctx.reset();
|
|
14207
14886
|
}
|
|
14208
14887
|
|
|
14209
14888
|
// TODO: enable async and synchronize
|
|
@@ -14212,7 +14891,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
|
|
|
14212
14891
|
/* .free = */ ggml_backend_vk_free,
|
|
14213
14892
|
/* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
|
|
14214
14893
|
/* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
|
|
14215
|
-
/* .cpy_tensor_async = */
|
|
14894
|
+
/* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
|
|
14216
14895
|
/* .synchronize = */ ggml_backend_vk_synchronize,
|
|
14217
14896
|
/* .graph_plan_create = */ NULL,
|
|
14218
14897
|
/* .graph_plan_free = */ NULL,
|
|
@@ -14413,13 +15092,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14413
15092
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
14414
15093
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
14415
15094
|
|
|
15095
|
+
const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
|
|
15096
|
+
device->shader_int64 && device->buffer_device_address;
|
|
15097
|
+
|
|
15098
|
+
auto const & tensor_size_supported = [&](size_t tensor_size) {
|
|
15099
|
+
if (tensor_size > device->max_buffer_size) {
|
|
15100
|
+
return false;
|
|
15101
|
+
}
|
|
15102
|
+
// For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
|
|
15103
|
+
// If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
|
|
15104
|
+
if (!uses_bda && !device->shader_64b_indexing) {
|
|
15105
|
+
if (tensor_size > device->properties.limits.maxStorageBufferRange) {
|
|
15106
|
+
return false;
|
|
15107
|
+
}
|
|
15108
|
+
}
|
|
15109
|
+
return true;
|
|
15110
|
+
};
|
|
14416
15111
|
// reject any tensors larger than the max buffer size
|
|
14417
15112
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
14418
|
-
if (op->src[i] && ggml_nbytes(op->src[i])
|
|
15113
|
+
if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
|
|
14419
15114
|
return false;
|
|
14420
15115
|
}
|
|
14421
15116
|
}
|
|
14422
|
-
if (ggml_nbytes(op)
|
|
15117
|
+
if (!tensor_size_supported(ggml_nbytes(op))) {
|
|
14423
15118
|
return false;
|
|
14424
15119
|
}
|
|
14425
15120
|
|
|
@@ -14427,6 +15122,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14427
15122
|
case GGML_OP_UNARY:
|
|
14428
15123
|
switch (ggml_get_unary_op(op)) {
|
|
14429
15124
|
case GGML_UNARY_OP_EXP:
|
|
15125
|
+
case GGML_UNARY_OP_ELU:
|
|
14430
15126
|
case GGML_UNARY_OP_GELU:
|
|
14431
15127
|
case GGML_UNARY_OP_GELU_ERF:
|
|
14432
15128
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
@@ -14445,6 +15141,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14445
15141
|
case GGML_UNARY_OP_CEIL:
|
|
14446
15142
|
case GGML_UNARY_OP_FLOOR:
|
|
14447
15143
|
case GGML_UNARY_OP_TRUNC:
|
|
15144
|
+
case GGML_UNARY_OP_SGN:
|
|
14448
15145
|
return ggml_is_contiguous(op->src[0]) &&
|
|
14449
15146
|
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
14450
15147
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
@@ -14707,6 +15404,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14707
15404
|
case GGML_OP_REPEAT_BACK:
|
|
14708
15405
|
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
|
14709
15406
|
case GGML_OP_ROPE:
|
|
15407
|
+
return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
|
|
14710
15408
|
case GGML_OP_ROPE_BACK:
|
|
14711
15409
|
case GGML_OP_NONE:
|
|
14712
15410
|
case GGML_OP_RESHAPE:
|
|
@@ -14717,8 +15415,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14717
15415
|
return true;
|
|
14718
15416
|
case GGML_OP_NORM:
|
|
14719
15417
|
case GGML_OP_GROUP_NORM:
|
|
14720
|
-
case GGML_OP_L2_NORM:
|
|
14721
15418
|
return ggml_is_contiguous(op->src[0]);
|
|
15419
|
+
case GGML_OP_L2_NORM:
|
|
15420
|
+
return ggml_is_contiguous_rows(op->src[0]) &&
|
|
15421
|
+
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
|
14722
15422
|
case GGML_OP_ADD:
|
|
14723
15423
|
case GGML_OP_SUB:
|
|
14724
15424
|
case GGML_OP_MUL:
|
|
@@ -14781,7 +15481,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14781
15481
|
}
|
|
14782
15482
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
14783
15483
|
case GGML_OP_ACC:
|
|
14784
|
-
return op->src[0]->type == GGML_TYPE_F32;
|
|
15484
|
+
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
|
15485
|
+
case GGML_OP_SET:
|
|
15486
|
+
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
|
|
15487
|
+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
|
|
14785
15488
|
case GGML_OP_CONCAT:
|
|
14786
15489
|
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
|
|
14787
15490
|
case GGML_OP_ADD1:
|
|
@@ -14855,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
14855
15558
|
case GGML_OP_RWKV_WKV6:
|
|
14856
15559
|
case GGML_OP_RWKV_WKV7:
|
|
14857
15560
|
return true; // all inputs are contiguous, see ggml.c
|
|
15561
|
+
case GGML_OP_GATED_DELTA_NET:
|
|
15562
|
+
{
|
|
15563
|
+
const uint32_t S_v = op->src[2]->ne[0];
|
|
15564
|
+
if (S_v != 32 && S_v != 64 && S_v != 128) {
|
|
15565
|
+
return false;
|
|
15566
|
+
}
|
|
15567
|
+
for (int i = 0; i < 6; i++) {
|
|
15568
|
+
if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
|
|
15569
|
+
return false;
|
|
15570
|
+
}
|
|
15571
|
+
}
|
|
15572
|
+
return op->type == GGML_TYPE_F32;
|
|
15573
|
+
}
|
|
14858
15574
|
case GGML_OP_SSM_SCAN:
|
|
14859
15575
|
{
|
|
14860
15576
|
for (int i = 0; i < 6; i++) {
|
|
@@ -14926,11 +15642,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
|
|
|
14926
15642
|
return buft_ctx->device->idx == ctx->device;
|
|
14927
15643
|
}
|
|
14928
15644
|
|
|
15645
|
+
static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {
|
|
15646
|
+
switch (op->op) {
|
|
15647
|
+
case GGML_OP_GET_ROWS:
|
|
15648
|
+
return 0;
|
|
15649
|
+
case GGML_OP_MUL_MAT:
|
|
15650
|
+
return op->ne[1];
|
|
15651
|
+
case GGML_OP_MUL_MAT_ID:
|
|
15652
|
+
case GGML_OP_ROPE:
|
|
15653
|
+
case GGML_OP_ROPE_BACK:
|
|
15654
|
+
return op->ne[2];
|
|
15655
|
+
default:
|
|
15656
|
+
return ggml_nrows(op);
|
|
15657
|
+
}
|
|
15658
|
+
}
|
|
15659
|
+
|
|
14929
15660
|
static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
|
14930
15661
|
ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
14931
15662
|
|
|
14932
|
-
return (op
|
|
14933
|
-
(op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
|
|
15663
|
+
return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
|
|
14934
15664
|
}
|
|
14935
15665
|
|
|
14936
15666
|
static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
|
|
@@ -14972,6 +15702,10 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
|
|
|
14972
15702
|
vk_event *vkev = (vk_event *)event->context;
|
|
14973
15703
|
|
|
14974
15704
|
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
|
|
15705
|
+
// Finished using current command buffer so we flag for reuse
|
|
15706
|
+
if (vkev->cmd_buffer) {
|
|
15707
|
+
vkev->cmd_buffer->in_use = false;
|
|
15708
|
+
}
|
|
14975
15709
|
}
|
|
14976
15710
|
|
|
14977
15711
|
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
|
|
@@ -15190,6 +15924,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
|
|
|
15190
15924
|
}
|
|
15191
15925
|
}
|
|
15192
15926
|
|
|
15927
|
+
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
|
|
15928
|
+
VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
|
|
15929
|
+
|
|
15930
|
+
if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
|
|
15931
|
+
return 0;
|
|
15932
|
+
}
|
|
15933
|
+
|
|
15934
|
+
const uint32_t device_id = props.properties.deviceID;
|
|
15935
|
+
|
|
15936
|
+
switch (device_id) {
|
|
15937
|
+
case 0x56A6: // A310
|
|
15938
|
+
return 6;
|
|
15939
|
+
case 0x5693: // A370M
|
|
15940
|
+
case 0x56A5: // A380
|
|
15941
|
+
case 0x56B1: // Pro A40/A50
|
|
15942
|
+
return 8;
|
|
15943
|
+
case 0x5697: // A530M
|
|
15944
|
+
return 12;
|
|
15945
|
+
case 0x5692: // A550M
|
|
15946
|
+
case 0x56B3: // Pro A60
|
|
15947
|
+
return 16;
|
|
15948
|
+
case 0x56A2: // A580
|
|
15949
|
+
return 24;
|
|
15950
|
+
case 0x5691: // A730M
|
|
15951
|
+
case 0x56A1: // A750
|
|
15952
|
+
return 28;
|
|
15953
|
+
case 0x56A0: // A770
|
|
15954
|
+
case 0x5690: // A770M
|
|
15955
|
+
return 32;
|
|
15956
|
+
case 0xE212: // Pro B50
|
|
15957
|
+
return 16;
|
|
15958
|
+
case 0xE20C: // B570
|
|
15959
|
+
return 18;
|
|
15960
|
+
case 0xE20B: // B580
|
|
15961
|
+
return 20;
|
|
15962
|
+
default:
|
|
15963
|
+
return 0;
|
|
15964
|
+
}
|
|
15965
|
+
}
|
|
15966
|
+
|
|
15193
15967
|
// checks
|
|
15194
15968
|
|
|
15195
15969
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
@@ -15403,7 +16177,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15403
16177
|
tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
|
|
15404
16178
|
} else if (tensor->op == GGML_OP_FILL) {
|
|
15405
16179
|
const float value = ggml_get_op_params_f32(tensor, 0);
|
|
15406
|
-
tensor_clone = ggml_fill(ggml_ctx,
|
|
16180
|
+
tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
|
|
15407
16181
|
} else if (tensor->op == GGML_OP_SQR) {
|
|
15408
16182
|
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
|
15409
16183
|
} else if (tensor->op == GGML_OP_SQRT) {
|
|
@@ -15432,6 +16206,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15432
16206
|
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
|
|
15433
16207
|
} else if (tensor->op == GGML_OP_ACC) {
|
|
15434
16208
|
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
|
16209
|
+
} else if (tensor->op == GGML_OP_SET) {
|
|
16210
|
+
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
|
15435
16211
|
} else if (tensor->op == GGML_OP_NORM) {
|
|
15436
16212
|
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
|
15437
16213
|
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
|
@@ -15488,6 +16264,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15488
16264
|
case GGML_UNARY_OP_EXP:
|
|
15489
16265
|
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
|
|
15490
16266
|
break;
|
|
16267
|
+
case GGML_UNARY_OP_ELU:
|
|
16268
|
+
tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
|
|
16269
|
+
break;
|
|
15491
16270
|
case GGML_UNARY_OP_SILU:
|
|
15492
16271
|
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
|
|
15493
16272
|
break;
|
|
@@ -15546,6 +16325,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15546
16325
|
case GGML_UNARY_OP_TRUNC:
|
|
15547
16326
|
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
|
|
15548
16327
|
break;
|
|
16328
|
+
case GGML_UNARY_OP_SGN:
|
|
16329
|
+
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
|
|
16330
|
+
break;
|
|
15549
16331
|
default:
|
|
15550
16332
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
15551
16333
|
GGML_ABORT("fatal error");
|
|
@@ -15666,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15666
16448
|
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
|
|
15667
16449
|
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
|
|
15668
16450
|
src_clone[4], src_clone[5], src_clone[6]);
|
|
16451
|
+
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
|
|
16452
|
+
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
|
|
16453
|
+
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
|
15669
16454
|
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
|
15670
16455
|
src_clone[0]->flags = tensor->src[0]->flags;
|
|
15671
16456
|
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
|
@@ -15864,7 +16649,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
15864
16649
|
ggml_vk_print_graph_origin(tensor, done);
|
|
15865
16650
|
}
|
|
15866
16651
|
|
|
15867
|
-
if (avg_err > 0.
|
|
16652
|
+
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
|
15868
16653
|
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
|
|
15869
16654
|
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
|
|
15870
16655
|
if (src0 != nullptr) {
|