llama-cpp-pydist 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- llama_cpp/binaries/{llama-b7488-bin-win-cpu-x64.zip → llama-b7631-bin-win-cpu-x64.zip} +0 -0
- llama_cpp_pydist-0.21.0.dist-info/METADATA +4684 -0
- {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/RECORD +240 -222
- {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/WHEEL +1 -1
- vendor_llama_cpp_pydist/llama.cpp/.devops/cuda-new.Dockerfile +95 -0
- vendor_llama_cpp_pydist/llama.cpp/.gemini/settings.json +1 -0
- vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +2 -1
- vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/011-bug-results.yml +13 -2
- vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/019-bug-misc.yml +13 -2
- vendor_llama_cpp_pydist/llama.cpp/.github/workflows/build.yml +18 -6
- vendor_llama_cpp_pydist/llama.cpp/.github/workflows/docker.yml +25 -13
- vendor_llama_cpp_pydist/llama.cpp/.github/workflows/release.yml +9 -5
- vendor_llama_cpp_pydist/llama.cpp/.github/workflows/server.yml +18 -0
- vendor_llama_cpp_pydist/llama.cpp/AGENTS.md +81 -0
- vendor_llama_cpp_pydist/llama.cpp/CLAUDE.md +1 -0
- vendor_llama_cpp_pydist/llama.cpp/CONTRIBUTING.md +34 -5
- vendor_llama_cpp_pydist/llama.cpp/ci/run.sh +2 -1
- vendor_llama_cpp_pydist/llama.cpp/common/CMakeLists.txt +4 -3
- vendor_llama_cpp_pydist/llama.cpp/common/arg.cpp +46 -14
- vendor_llama_cpp_pydist/llama.cpp/common/arg.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/common/chat-parser.cpp +11 -0
- vendor_llama_cpp_pydist/llama.cpp/common/chat.cpp +36 -7
- vendor_llama_cpp_pydist/llama.cpp/common/chat.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/common/common.cpp +42 -23
- vendor_llama_cpp_pydist/llama.cpp/common/common.h +7 -2
- vendor_llama_cpp_pydist/llama.cpp/common/llguidance.cpp +10 -6
- vendor_llama_cpp_pydist/llama.cpp/common/regex-partial.cpp +13 -13
- vendor_llama_cpp_pydist/llama.cpp/common/sampling.cpp +58 -14
- vendor_llama_cpp_pydist/llama.cpp/common/sampling.h +3 -1
- vendor_llama_cpp_pydist/llama.cpp/convert_hf_to_gguf.py +424 -103
- vendor_llama_cpp_pydist/llama.cpp/convert_hf_to_gguf_update.py +5 -0
- vendor_llama_cpp_pydist/llama.cpp/docs/backend/CANN.md +4 -0
- vendor_llama_cpp_pydist/llama.cpp/docs/backend/OPENCL.md +51 -1
- vendor_llama_cpp_pydist/llama.cpp/docs/backend/SYCL.md +1 -1
- vendor_llama_cpp_pydist/llama.cpp/docs/backend/hexagon/README.md +5 -5
- vendor_llama_cpp_pydist/llama.cpp/docs/backend/hexagon/developer.md +1 -1
- vendor_llama_cpp_pydist/llama.cpp/docs/build.md +21 -2
- vendor_llama_cpp_pydist/llama.cpp/docs/development/parsing.md +2 -2
- vendor_llama_cpp_pydist/llama.cpp/docs/ops/Metal.csv +360 -322
- vendor_llama_cpp_pydist/llama.cpp/docs/ops.md +1 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/CMakeLists.txt +13 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/include/ggml-backend.h +1 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/CMakeLists.txt +23 -9
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-backend.cpp +11 -11
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +303 -19
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +17 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/common.h +153 -9
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +51 -158
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +86 -25
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +15 -8
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +768 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +0 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +66 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/argsort.cu +48 -27
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/argsort.cuh +16 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/common.cuh +45 -9
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/cpy.cu +117 -103
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/cumsum.cu +105 -35
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +3 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +2 -2
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +83 -33
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mean.cu +3 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mma.cuh +21 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mmq.cu +34 -8
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +168 -13
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/quantize.cu +151 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/quantize.cuh +14 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/softmax.cu +203 -6
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/top-k.cu +96 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/top-k.cuh +3 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +17 -2
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +6 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +224 -758
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +316 -164
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +5 -11
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +46 -15
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +9 -3
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +2 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +20 -20
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-impl.h +0 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +57 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +2 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +5 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +20 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +71 -2
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +73 -6
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +134 -13
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +21 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +14 -7
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +42 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +742 -315
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +28 -14
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +1 -7
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl +2 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +17 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +42 -24
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +11 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +115 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +10 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +29 -18
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +19 -16
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +10 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +8 -8
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +11 -4
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +4 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +4 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +4 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +1 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +4 -1
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +57 -22
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +312 -6
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +54 -0
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -2
- vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
- vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/constants.py +99 -0
- vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/gguf_writer.py +38 -2
- vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/tensor_mapping.py +26 -0
- vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/utility.py +0 -8
- vendor_llama_cpp_pydist/llama.cpp/grammars/README.md +3 -0
- vendor_llama_cpp_pydist/llama.cpp/include/llama.h +99 -12
- vendor_llama_cpp_pydist/llama.cpp/scripts/snapdragon/adb/run-cli.sh +9 -9
- vendor_llama_cpp_pydist/llama.cpp/scripts/snapdragon/adb/run-completion.sh +53 -0
- vendor_llama_cpp_pydist/llama.cpp/scripts/sync-ggml.last +1 -1
- vendor_llama_cpp_pydist/llama.cpp/src/CMakeLists.txt +4 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-adapter.cpp +12 -3
- vendor_llama_cpp_pydist/llama.cpp/src/llama-adapter.h +7 -1
- vendor_llama_cpp_pydist/llama.cpp/src/llama-arch.cpp +76 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-arch.h +7 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-chat.cpp +11 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-chat.h +1 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-context.cpp +625 -40
- vendor_llama_cpp_pydist/llama.cpp/src/llama-context.h +43 -1
- vendor_llama_cpp_pydist/llama.cpp/src/llama-grammar.cpp +40 -13
- vendor_llama_cpp_pydist/llama.cpp/src/llama-grammar.h +2 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-graph.cpp +166 -2
- vendor_llama_cpp_pydist/llama.cpp/src/llama-graph.h +71 -6
- vendor_llama_cpp_pydist/llama.cpp/src/llama-hparams.h +6 -5
- vendor_llama_cpp_pydist/llama.cpp/src/llama-kv-cache.h +1 -1
- vendor_llama_cpp_pydist/llama.cpp/src/llama-mmap.cpp +11 -4
- vendor_llama_cpp_pydist/llama.cpp/src/llama-model-loader.cpp +23 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-model-loader.h +2 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama-model.cpp +329 -26
- vendor_llama_cpp_pydist/llama.cpp/src/llama-model.h +13 -2
- vendor_llama_cpp_pydist/llama.cpp/src/llama-sampling.cpp +1259 -186
- vendor_llama_cpp_pydist/llama.cpp/src/llama-sampling.h +19 -7
- vendor_llama_cpp_pydist/llama.cpp/src/llama-vocab.cpp +101 -33
- vendor_llama_cpp_pydist/llama.cpp/src/llama-vocab.h +2 -0
- vendor_llama_cpp_pydist/llama.cpp/src/llama.cpp +53 -38
- vendor_llama_cpp_pydist/llama.cpp/src/models/afmoe.cpp +9 -5
- vendor_llama_cpp_pydist/llama.cpp/src/models/bert.cpp +4 -2
- vendor_llama_cpp_pydist/llama.cpp/src/models/cogvlm.cpp +5 -3
- vendor_llama_cpp_pydist/llama.cpp/src/models/cohere2-iswa.cpp +3 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/deepseek2.cpp +1 -1
- vendor_llama_cpp_pydist/llama.cpp/src/models/gemma-embedding.cpp +2 -6
- vendor_llama_cpp_pydist/llama.cpp/src/models/gemma2-iswa.cpp +5 -2
- vendor_llama_cpp_pydist/llama.cpp/src/models/gemma3.cpp +3 -4
- vendor_llama_cpp_pydist/llama.cpp/src/models/gemma3n-iswa.cpp +4 -7
- vendor_llama_cpp_pydist/llama.cpp/src/models/llama-iswa.cpp +6 -2
- vendor_llama_cpp_pydist/llama.cpp/src/models/llama.cpp +19 -6
- vendor_llama_cpp_pydist/llama.cpp/src/models/maincoder.cpp +117 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/mimo2-iswa.cpp +123 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/models.h +18 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/modern-bert.cpp +116 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/openai-moe-iswa.cpp +5 -2
- vendor_llama_cpp_pydist/llama.cpp/src/models/plamo3.cpp +128 -0
- vendor_llama_cpp_pydist/llama.cpp/src/models/smallthinker.cpp +11 -5
- vendor_llama_cpp_pydist/llama.cpp/src/unicode.cpp +23 -14
- vendor_llama_cpp_pydist/llama.cpp/tests/CMakeLists.txt +12 -2
- vendor_llama_cpp_pydist/llama.cpp/tests/test-backend-ops.cpp +286 -65
- vendor_llama_cpp_pydist/llama.cpp/tests/test-backend-sampler.cpp +1237 -0
- vendor_llama_cpp_pydist/llama.cpp/tests/test-chat.cpp +29 -3
- vendor_llama_cpp_pydist/llama.cpp/tests/test-grammar-llguidance.cpp +3 -0
- vendor_llama_cpp_pydist/llama.cpp/tests/test-regex-partial.cpp +14 -14
- vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-0.cpp +1 -1
- vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-1-bpe.cpp +1 -1
- vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-1-spm.cpp +1 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/batched-bench/batched-bench.cpp +11 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/cli/README.md +187 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/cli/cli.cpp +1 -3
- vendor_llama_cpp_pydist/llama.cpp/tools/completion/README.md +179 -7
- vendor_llama_cpp_pydist/llama.cpp/tools/completion/completion.cpp +4 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/fit-params/fit-params.cpp +3 -3
- vendor_llama_cpp_pydist/llama.cpp/tools/llama-bench/llama-bench.cpp +18 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/CMakeLists.txt +1 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip-impl.h +12 -7
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip-model.h +3 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip.cpp +118 -4
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/models.h +10 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/siglip.cpp +9 -4
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/whisper-enc.cpp +9 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/youtuvl.cpp +179 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/mtmd.cpp +5 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/mtmd.h +3 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/quantize/quantize.cpp +6 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/CMakeLists.txt +0 -8
- vendor_llama_cpp_pydist/llama.cpp/tools/server/README-dev.md +2 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/README.md +27 -14
- vendor_llama_cpp_pydist/llama.cpp/tools/server/public/index.html.gz +0 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-common.cpp +22 -24
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-common.h +2 -3
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-context.cpp +453 -267
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-context.h +52 -15
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-http.cpp +16 -10
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-models.cpp +174 -62
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-models.h +14 -5
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-queue.cpp +78 -21
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-queue.h +48 -10
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-task.cpp +36 -11
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server-task.h +28 -35
- vendor_llama_cpp_pydist/llama.cpp/tools/server/server.cpp +9 -5
- vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/unit/test_chat_completion.py +11 -2
- vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/unit/test_sleep.py +39 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/utils.py +3 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte +25 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +66 -13
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +5 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/constants/settings-config.ts +3 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +125 -11
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/services/chat.ts +15 -8
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/stores/chat.svelte.ts +12 -3
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/stores/settings.svelte.ts +4 -5
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/types/api.d.ts +5 -0
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/types/settings.d.ts +2 -1
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/utils/clipboard.ts +1 -4
- vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/routes/+layout.svelte +1 -1
- llama_cpp_pydist-0.19.0.dist-info/METADATA +0 -2506
- vendor_llama_cpp_pydist/llama.cpp/.github/copilot-instructions.md +0 -262
- {llama_cpp_pydist-0.19.0.dist-info/licenses → llama_cpp_pydist-0.21.0.dist-info}/LICENSE +0 -0
- {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/top_level.txt +0 -0
vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp (excerpt of the diff):

@@ -379,18 +379,18 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
 
     uint32_t HSK, HSV;
-    bool small_rows;
+    bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
     }
 };
 
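The hunk above extends `vk_fa_pipeline_state`, which the backend uses as the key of a `std::map` of flash-attention pipelines, so `operator<` must order every field, including the new `small_cache`. Below is a minimal standalone sketch of that `std::tie` ordering pattern; the simplified `fa_key` struct and the `int` value type are hypothetical stand-ins, not the real types.

```cpp
// std::tie builds a tuple of references, and tuple::operator< compares the
// fields lexicographically, which is all std::map needs from its key type.
#include <cstdint>
#include <map>
#include <tuple>

struct fa_key {
    uint32_t HSK, HSV;
    bool small_rows, small_cache;

    bool operator<(const fa_key &b) const {
        return std::tie(HSK, HSV, small_rows, small_cache) <
               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache);
    }
};

int main() {
    std::map<fa_key, int> pipelines;           // stands in for std::map<vk_fa_pipeline_state, vk_pipeline>
    pipelines[{128, 128, false, true}]  = 1;   // adding small_cache to the key keeps
    pipelines[{128, 128, false, false}] = 2;   // the two variants as distinct map entries
    return static_cast<int>(pipelines.size()); // 2
}
```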
@@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
                                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
                                                                              GGML_OP_RESHAPE };
+
+static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
+                                                                            GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                                            GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                                            GGML_OP_DIV, GGML_OP_RESHAPE };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                          GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
                                                                         GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                         GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
     { 9, 0, 8 }, // reshape->src[0] == div
 };
 
+//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
+//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
+//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
+//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
+//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
+//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
+//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
+//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
+//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
+//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
+//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
+    { 1, 0, 0 }, // reshape->src[0] == sigmoid
+    { 2, 0, 0 }, // add->src[0] == sigmoid
+    { 3, 0, 2 }, // argsort->src[0] == add
+    { 4, 0, 3 }, // view->src[0] == argsort
+    { 5, 0, 1 }, // get_rows->src[0] == reshape
+    { 5, 1, 4 }, // get_rows->src[1] == view
+    { 6, 0, 5 }, // reshape->src[0] == get_rows
+    { 7, 0, 6 }, // sum_rows->src[0] == reshape
+    { 8, 0, 7 }, // clamp->src[0] == sum_rows
+    { 9, 0, 6 }, // div->src[0] == reshape
+    { 9, 1, 8 }, // div->src[1] == clamp
+    {10, 0, 9 }, // reshape->src[0] == div
+};
+
 // same as early_softmax_norm but ending after the get_rows
 static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
     { 1, 0, 0 }, // reshape->src[0] == softmax
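The edge tables above describe a fused run of graph nodes as `{consumer, src_slot, producer}` triples (see the `// reshape->src[0] == div` style comments). A hedged sketch of how such a table can be checked against a candidate run of nodes follows; the `node` type is a hypothetical stand-in for the real graph tensors, not the backend's actual matcher.

```cpp
#include <array>
#include <initializer_list>
#include <iostream>
#include <vector>

struct node {
    std::array<const node *, 4> src{};  // producer tensors, indexed by source slot
};

// Each {consumer, src_slot, producer} entry asserts that nodes[consumer]
// reads nodes[producer] through src[src_slot].
static bool edges_match(const std::vector<const node *> &nodes,
                        std::initializer_list<std::array<int, 3>> edges) {
    for (const auto &e : edges) {
        const int consumer = e[0], slot = e[1], producer = e[2];
        if (nodes[consumer]->src[slot] != nodes[producer]) {
            return false;  // fusion pattern does not hold for this run of nodes
        }
    }
    return true;
}

int main() {
    // tiny 3-node chain: n1 reads n0, n2 reads n1
    node n0, n1, n2;
    n1.src[0] = &n0;
    n2.src[0] = &n1;
    std::vector<const node *> run = { &n0, &n1, &n2 };
    std::cout << edges_match(run, { {1, 0, 0}, {2, 0, 1} }) << "\n";  // 1 (pattern matches)
    return 0;
}
```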
@@ -491,16 +524,10 @@ enum topk_moe_mode {
     TOPK_MOE_EARLY_SOFTMAX,
     TOPK_MOE_EARLY_SOFTMAX_NORM,
     TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_SIGMOID_NORM_BIAS,
     TOPK_MOE_COUNT,
 };
 
-static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
-    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
-                         num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
-                         TOPK_MOE_LATE_SOFTMAX;
-    return mode;
-}
-
 static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
     { 1, 0, 0 }, // view->src[0] == rope
     { 2, 0, 1 }, // set_rows->src[0] == view
@@ -651,7 +678,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
-    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
+    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
     vk_pipeline pipeline_sqrt_f32;
@@ -689,6 +716,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_xielu[2];
     vk_pipeline pipeline_neg[2];
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
@@ -730,13 +758,16 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
-    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
+    vk_pipeline pipeline_cumsum_small_f32;
+    vk_pipeline pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline_cumsum_multipass2_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@@ -762,9 +793,10 @@ struct vk_device_struct {
     std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
+    vk_pipeline pipeline_count_experts;
 
     // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
 
     std::vector<vk_pipeline_ref> all_pipelines;
 
@@ -855,6 +887,15 @@ struct vk_subbuffer {
     }
 };
 
+// vk_event is used for the event-related backend interfaces. It uses 'event' for
+// event_wait and 'fence' for event_synchronize. Polling on an event for
+// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
+// and would lead to validation errors.
+struct vk_event {
+    vk::Event event;
+    vk::Fence fence;
+};
+
 struct vk_semaphore {
     vk::Semaphore s;
     uint64_t value;
@@ -990,6 +1031,16 @@ struct vk_op_push_constants {
     uint32_t KY;
     float param1;
     float param2;
+    float param3;
+    float param4;
+};
+
+struct vk_op_count_experts_push_constants {
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t nb00;
+    uint32_t nb01;
+    uint32_t a_offset;
 };
 
 struct vk_op_glu_push_constants {
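Push-constant structs like the ones above have to mirror the shader's push-constant block and stay small, since Vulkan only guarantees a `maxPushConstantsSize` of 128 bytes. A sketch under those assumptions; the struct is a simplified mirror of the fields shown in the hunk, and the asserts are illustrative rather than taken from the source.

```cpp
#include <cstdint>

// Simplified host-side mirror of a per-op push-constant block: all members are
// 4-byte scalars, so the struct packs tightly and matches a std430-style layout.
struct op_push_constants {
    uint32_t KY;
    float param1;
    float param2;
    float param3;
    float param4;
};

static_assert(sizeof(op_push_constants) == 5 * 4, "tightly packed 4-byte scalars");
static_assert(sizeof(op_push_constants) <= 128, "fits the minimum guaranteed push-constant budget");

int main() { return 0; }
```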
@@ -1160,6 +1211,11 @@ struct vk_op_topk_moe_push_constants {
     uint32_t n_expert_used;
     float clamp_min;
     float clamp_max;
+    uint32_t gating_func;
+    uint32_t has_bias;
+    uint32_t with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 struct vk_op_add_id_push_constants {
@@ -1180,6 +1236,7 @@ struct vk_op_diag_mask_push_constants {
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
     uint32_t ncols;
+    uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
     uint32_t p_delta_rows;
@@ -1258,6 +1315,7 @@ struct vk_op_im2col_push_constants {
     int32_t s0; int32_t s1;
     int32_t p0; int32_t p1;
     int32_t d0; int32_t d1;
+    uint32_t batch_IC;
 };
 
 struct vk_op_im2col_3d_push_constants {
@@ -1551,7 +1609,7 @@ class vk_perf_logger {
                 total_op_times += time;
             }
             std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
-                      << " us";
+                      << " us = " << (total_op_times / 1000.0) << " us";
 
             // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
             auto it = flops.find(t.first);
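The perf-logger change above makes the summary print both the per-call average and the total time for each op. A minimal sketch of that arithmetic, with a hypothetical vector of nanosecond samples standing in for the logger's recorded times:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    std::vector<uint64_t> samples_ns = {12000, 15000, 9000};  // per-call GPU times in ns

    uint64_t total_ns = 0;
    for (uint64_t t : samples_ns) {
        total_ns += t;
    }

    // average per call and total, both reported in microseconds
    std::cout << samples_ns.size() << " x " << (total_ns / samples_ns.size() / 1000.0)
              << " us = " << (total_ns / 1000.0) << " us\n";  // prints "3 x 12 us = 36 us"
    return 0;
}
```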
@@ -1748,6 +1806,8 @@ struct ggml_backend_vk_context {
     // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
     // If there's no fusion, bit 0 is still set.
     int fused_ops_write_mask {};
+    topk_moe_mode fused_topk_moe_mode {};
+    bool fused_topk_moe_scale {};
 
     // for GGML_VK_PERF_LOGGER
     std::unique_ptr<vk_perf_logger> perf_logger;
@@ -2540,6 +2600,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
     );
 }
 
+static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
+    VK_LOG_DEBUG("ggml_vk_set_event()");
+
+    ctx->s->buffer.setEvent(
+        event,
+        ctx->p->q->stage_flags
+    );
+}
+
 static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
     VK_LOG_DEBUG("ggml_vk_wait_events()");
     if (events.empty()) {
@@ -2560,10 +2629,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
 
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
+static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
     if (hsv >= 192) {
         return 2;
-    } else if ((hsv | hsk) & 8) {
+    } else if ((hsv | hsk) & 8 || small_cache) {
         return 4;
     } else {
         return 8;
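The change above makes a small KV cache (`small_cache`) reduce the scalar flash-attention row count the same way head sizes that are not a multiple of 16 already did. A standalone copy of that selection logic with a few sample evaluations (the sample head sizes are illustrative):

```cpp
#include <cstdint>
#include <iostream>

static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
    if (hsv >= 192) {
        return 2;
    } else if ((hsv | hsk) & 8 || small_cache) {
        return 4;
    } else {
        return 8;
    }
}

int main() {
    std::cout << get_fa_scalar_num_large_rows(128, 128, false) << "\n";  // 8
    std::cout << get_fa_scalar_num_large_rows(128, 128, true)  << "\n";  // 4: small cache lowers the row count
    std::cout << get_fa_scalar_num_large_rows(192, 256, true)  << "\n";  // 2: large HSV still wins
    return 0;
}
```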
@@ -2585,9 +2654,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }
 
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
     GGML_UNUSED(clamp);
-    GGML_UNUSED(hsv);
 
     if (path == FA_SCALAR) {
         if (small_rows) {
@@ -2596,9 +2664,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
             if ((hsv | hsk) & 8) {
                 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
                 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
+                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
             } else {
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
+                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
             }
         }
     }
@@ -2627,8 +2695,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     return {64, 64};
 }
 
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
+    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
 }
 
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2637,7 +2705,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     switch (src0_type) {
     case GGML_TYPE_IQ1_S:
     case GGML_TYPE_IQ1_M:
-        lut_size = 2*2048;
+        lut_size = 2*2048 + 4*2048;
         break;
     case GGML_TYPE_IQ2_XXS:
         lut_size = 8*256;
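The larger `lut_size` for the IQ1 types feeds the shared-memory budget that `ggml_vk_matmul_shmem_support` checks. A hedged sketch of that kind of check is below; the tile footprint and the 48 KiB limit are illustrative assumptions, not values from the source.

```cpp
#include <cstddef>
#include <iostream>

// A per-type dequantization LUT plus the matmul tile must fit within the
// device's shared-memory limit for the shader variant to be usable.
static bool fits_shared_memory(size_t lut_size, size_t tile_bytes, size_t max_shared) {
    return lut_size + tile_bytes <= max_shared;
}

int main() {
    const size_t lut_iq1   = 2 * 2048 + 4 * 2048;   // IQ1_S/IQ1_M LUT size from the hunk above
    const size_t tile_demo = 32 * 1024;             // assumed tile footprint
    const size_t limit     = 48 * 1024;             // assumed maxComputeSharedMemorySize
    std::cout << (fits_shared_memory(lut_iq1, tile_demo, limit) ? "ok" : "too large") << "\n";  // "ok"
    return 0;
}
```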
@@ -2808,9 +2876,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128,
-        m_warptile_mmqid = { 256, 128, 64,
-        s_warptile_mmqid = { 256, 128, 64,
+        l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
+        m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
+        s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
         l_mmqid_wg_denoms = { 128, 128, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2830,39 +2898,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
         const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
 
-
-
-
+        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
+
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
-        l_warptile_mmq = { 128,
-        m_warptile_mmq = { 128,
-        s_warptile_mmq = { subgroup_size_32, 32,
+        l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
         // Integer MMQ has a smaller shared memory profile, but heavier register use
-        l_warptile_mmq_int = { 128,
-        m_warptile_mmq_int = { 128,
-        s_warptile_mmq_int = { subgroup_size_32, 32,
+        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };
 
         // K-quants use even more registers, mitigate by setting WMITER to 1
-        l_warptile_mmq_int_k = { 128,
-        m_warptile_mmq_int_k = { 128,
-        s_warptile_mmq_int_k = { subgroup_size_32,
+        l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };
 
-        l_warptile_id = { 128,
-        m_warptile_id = { 128,
-        s_warptile_id = { mul_mat_subgroup_size_16,
+        l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
+        m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
+        s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
 
-        l_warptile_mmqid = { 128,
-        m_warptile_mmqid = { 128,
-        s_warptile_mmqid = { mul_mat_subgroup_size_32,
+        l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int = { 128,
-        m_warptile_mmqid_int = { 128,
-        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,
+        l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int_k = { 128,
-        m_warptile_mmqid_int_k = { 128,
-        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,
+        l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+        m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
 
         // chip specific tuning
         if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
@@ -2970,11 +3040,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                 align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -2984,7 +3054,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
                             ? scalar_flash_attention_workgroup_size
                             : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2999,21 +3069,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t HSK = fa.first.HSK; \
         uint32_t HSV = fa.first.HSV; \
         bool small_rows = fa.first.small_rows; \
+        bool small_cache = fa.first.small_cache; \
         FaCodePath path = fa.first.path; \
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
         if (path == FAPATH) { \
            if (aligned) { \
                if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                } \
            } else { \
                if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
                } \
            } \
        } \
@@ -3045,17 +3116,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif
 #undef CREATE_FA
 
+    const int mul_mat_id_param_count = 5;
+
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {
 
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
 
         // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -3091,32 +3164,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         GGML_ASSERT(device->subgroup_ballot);
 
-        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants,
+        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         if (device->coopmat_bf16_support) {
-            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants,
+            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
         }
 #endif
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -3205,35 +3278,35 @@ static void ggml_vk_load_shaders(vk_device& device) {

   GGML_ASSERT(device->subgroup_ballot);

-  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants,
+  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
   #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
   if (device->coopmat_bf16_support) {
-  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants,
+  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
   }
   #endif

-  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants,
+  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
   #undef CREATE_MM2
   #undef CREATE_MM
   } else
@@ -3318,91 +3391,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
   #endif

   if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
-  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants,
-
-  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
+  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+
+  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

   #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
   if (device->integer_dot_product) {
-  CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

-  CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

-  CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
   }
   #endif
   } else {
-  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants,
-
-  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
+  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

   #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
   if (device->integer_dot_product) {
-  CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

-  CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

-  CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
-  CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants,
+  CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
   }
   #endif
   }
@@ -3479,57 +3552,57 @@ static void ggml_vk_load_shaders(vk_device& device) {
   #endif

   if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
-  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants,
-
-  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
+  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+
+  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
   } else {
-  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants,
-  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants,
-
-  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
-  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants,
+  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
   }
   }
   // reusing CREATE_MM from the fp32 path
@@ -3548,7 +3621,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  s_wg_denoms = { 32, 32, 1 };

  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants,
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
  }
  #undef CREATE_MM

@@ -3559,6 +3632,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  uint32_t rm_kq = 2;
  uint32_t rm_stdq_int = 1;
  uint32_t rm_kq_int = 1;
+ auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
  if (device->vendor_id == VK_VENDOR_ID_AMD) {
  if (device->architecture == AMD_GCN) {
  rm_stdq = 2;
@@ -3662,6 +3736,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
+
  }
  #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
  }
@@ -3708,6 +3786,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
  }
  #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
  }
@@ -3715,6 +3796,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  GGML_UNUSED(rm_stdq_int);
  GGML_UNUSED(rm_kq_int);
+ GGML_UNUSED(rm_iq_int);
  #endif

  // dequant shaders
@@ -3933,6 +4015,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

@@ -3973,6 +4056,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_UNARY(gelu_quick)
  CREATE_UNARY(silu)
  CREATE_UNARY(relu)
+ CREATE_UNARY(xielu)
  CREATE_UNARY(neg)
  CREATE_UNARY(tanh)
  CREATE_UNARY(sigmoid)
@@ -4054,6 +4138,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  } else {
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4062,6 +4147,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  }

  for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -4097,10 +4183,16 @@ static void ggml_vk_load_shaders(vk_device& device) {

  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

-
+ const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);

  ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

+ ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
+
  for (auto &s : device->pipeline_solve_tri_f32) {
  const vk_solve_tri_pipeline_state &state = s.first;

@@ -4251,9 +4343,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

  for (uint32_t use_push = 0; use_push < 2; ++use_push) {
  for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
  }
  }

@@ -5544,6 +5634,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  break;
  default:
  return nullptr;
@@ -5700,6 +5792,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  break;
  default:
  return nullptr;
@@ -5898,6 +5992,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
  std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
  }
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
+ GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+ wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+ wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
  GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
  GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
  GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@@ -6081,13 +6178,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
  }
  }

- static
+ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
  VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
- // Buffer is already mapped
- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
- std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
- GGML_ABORT("fatal error");
- }
  // Check if src is pinned memory
  vk_buffer buf = nullptr;
  size_t buf_offset = 0;
@@ -6112,12 +6204,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz

  ggml_vk_sync_buffers(nullptr, subctx);
  subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
- return;
+ return true;
  }
  VK_LOG_DEBUG("STAGING");

  if (!sync_staging) {
-
+ // copy was not handled caller needs to fall back
+ return false;
  }

  // Staging buffer required
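Note: the write helpers above now return a bool instead of aborting when the async path cannot service a copy. A minimal caller-side sketch of the intended fall-back pattern (identifiers follow the diff; the fallback body is illustrative only):

    // C++ sketch, assuming the names used in the hunks above
    if (!ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size)) {
        // the async path could not handle the copy without a sync staging buffer;
        // the caller supplies its own fallback, e.g. the staged copy added to
        // ggml_backend_vk_set_tensor_async later in this diff
    }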
@@ -6141,9 +6234,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
  deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
  }
  }
+ return true;
  }

- static
+ static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
  VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
  return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
  }
@@ -6162,7 +6256,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *

  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
  ggml_vk_ctx_begin(dst->device, subctx);
- ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
+ bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
+ GGML_ASSERT(ret);
  ggml_vk_ctx_end(subctx);

  for (auto& cpy : subctx->in_memcpys) {
@@ -6497,18 +6592,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context *

  static void ggml_vk_matmul_id(
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
- vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
  uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
  uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
  uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
  uint32_t padded_n) {
- VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
+ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
  "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
  "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
  "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
  const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
  nei0, nei1, nbi1, ne11, padded_n };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
  }

  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -6680,7 +6775,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub

  vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);

-
+ const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
+ // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
+ const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
+ const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
  ggml_vk_sync_buffers(ctx, subctx);
  }

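Note: a worked example of the dispatch clamp introduced in the hunk above, using assumed device limits and element counts (none of the numbers come from this diff):

    #include <algorithm>
    #include <cstdint>

    int main() {
        const uint32_t ne         = 10'000'000;   // elements to quantize (assumed)
        const uint32_t wg_denom   = 128;          // pipeline->wg_denoms[0] (assumed)
        const uint32_t max_wg     = 65535;        // maxComputeWorkGroupCount[0] (assumed)
        const uint32_t num_blocks = (ne + wg_denom - 1) / wg_denom;   // CEIL_DIV -> 78125 blocks total
        const uint64_t max_elems  = uint64_t(max_wg) * wg_denom;      // 8388480 elements per dispatch
        const uint32_t elements   = (uint32_t)std::min<uint64_t>(ne, max_elems);
        // the shader receives { ne, num_blocks } and loops until every block is processed
        (void)num_blocks; (void)elements;
        return 0;
    }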
@@ -6964,7 +7064,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
  // Quantization overhead is not worth it for small k
  switch (device->vendor_id) {
  case VK_VENDOR_ID_NVIDIA:
- if (src0_type == GGML_TYPE_Q2_K) {
+ if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
  return true;
  }

@@ -7491,6 +7591,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  const uint64_t nei0 = ids->ne[0];
  const uint64_t nei1 = ids->ne[1];

+ const uint32_t nbi0 = ids->nb[0];
  const uint32_t nbi1 = ids->nb[1];
  const uint32_t nbi2 = ids->nb[2];

@@ -7598,6 +7699,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  if (quantize_y) {
  to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
  }
+ vk_pipeline count_experts = ctx->device->pipeline_count_experts;
+
+ uint32_t expert_count_size = sizeof(uint32_t) * n_as;

  {
  if (
@@ -7613,6 +7717,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  ctx->prealloc_size_y = y_sz;
  ggml_vk_preallocate_buffers(ctx, subctx);
  }
+ if (ctx->prealloc_size_split_k < expert_count_size) {
+ ctx->prealloc_size_split_k = expert_count_size;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }

  // Request descriptor sets
  ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
@@ -7625,6 +7733,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  if (quantize_y) {
  ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
  }
+ ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
  }

  vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -7674,6 +7783,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  ggml_vk_sync_buffers(ctx, subctx);
  }
  }
+ // Count how many times each expert is used
+ vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ {
+ const std::vector<uint32_t> pc = { (uint32_t)nei0,
+ (uint32_t)nei1,
+ (uint32_t)(nbi0 / ggml_type_size(ids->type)),
+ (uint32_t)(nbi1 / ggml_type_size(ids->type)),
+ (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
+ ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
+ { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
+ }

  if (x_non_contig) {
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
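Note: a CPU-side sketch of what the count_experts pass dispatched above computes; the helper name and the row-stride parameter are hypothetical, only the ids/n_as semantics follow the push constants in the hunk:

    #include <cstdint>
    #include <vector>

    // Count how many (i0, i1) entries of the ids tensor select each expert.
    std::vector<uint32_t> count_experts_ref(const int32_t * ids, uint32_t nei0, uint32_t nei1,
                                            uint32_t row_stride, uint32_t n_as) {
        std::vector<uint32_t> counts(n_as, 0);
        for (uint32_t i1 = 0; i1 < nei1; ++i1) {
            for (uint32_t i0 = 0; i0 < nei0; ++i0) {
                const int32_t id = ids[i1 * row_stride + i0];
                if (id >= 0 && (uint32_t)id < n_as) {
                    counts[id]++;   // per-expert usage count, consumed by the matmul_id pipelines
                }
            }
        }
        return counts;
    }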
@@ -7681,7 +7804,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
  ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
  { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
- ggml_vk_sync_buffers(ctx, subctx);
  }
  if (y_non_contig) {
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@@ -7705,6 +7827,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  ctx->prealloc_y_last_tensor_used = src1;
  }
  }
+ ggml_vk_sync_buffers(ctx, subctx);

  uint32_t stride_batch_x = ne00*ne01;
  uint32_t stride_batch_y = ne10*ne11;
@@ -7721,7 +7844,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  ggml_vk_matmul_id(
  ctx, subctx, pipeline,
  { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
- { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz },
+ { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
  ne01, ne21, ne10, ne10, ne10, ne01,
  stride_batch_x, stride_batch_y, ne20*ne21,
  n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
@@ -7733,6 +7856,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  if (y_non_contig || quantize_y) {
  ctx->prealloc_y_need_sync = true;
  }
+ ctx->prealloc_split_k_need_sync = true;
  }

  static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7982,11 +8106,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
  }
  }

- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
  // Needs to be kept up to date on shader changes
  GGML_UNUSED(hsv);
  const uint32_t wg_size = scalar_flash_attention_workgroup_size;
- const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
  const uint32_t Bc = scalar_flash_attention_Bc;

  const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8110,6 +8234,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  uint32_t workgroups_y = (uint32_t)neq2;
  uint32_t workgroups_z = (uint32_t)neq3;

+ const bool small_cache = nek1 < 1024;
+
  // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
  // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
  uint32_t max_gqa;
@@ -8117,7 +8243,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  case FA_SCALAR:
  case FA_COOPMAT1:
  // We may switch from coopmat1 to scalar, so use the scalar limit for both
- max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
  break;
  case FA_COOPMAT2:
  max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8151,7 +8277,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

  // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
  if (path == FA_SCALAR &&
- !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
  small_rows = true;
  }

@@ -8167,7 +8293,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  v_stride /= 4;
  }

- uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
  bool aligned = (KV % alignment) == 0 &&
  // the "aligned" shader variant will forcibly align strides, for performance
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8179,7 +8305,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;

- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);

  vk_pipeline pipeline = nullptr;

@@ -8404,7 +8530,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return nullptr;
  case GGML_OP_UPSCALE:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-
+ uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
  switch (mode) {
  case GGML_SCALE_MODE_NEAREST:
  return ctx->device->pipeline_upscale_nearest_f32;
@@ -8412,6 +8538,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_upscale_bilinear_f32;
  case GGML_SCALE_MODE_BICUBIC:
  return ctx->device->pipeline_upscale_bicubic_f32;
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
+ return ctx->device->pipeline_upscale_bilinear_antialias_f32;
  default:
  return nullptr;
  }
@@ -8549,6 +8677,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
  case GGML_UNARY_OP_RELU:
  return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_XIELU:
+ return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
  case GGML_UNARY_OP_NEG:
  return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
  case GGML_UNARY_OP_TANH:
@@ -8613,10 +8743,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  if (ctx->num_additional_fused_ops) {
  uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
  GGML_ASSERT(idx < num_topk_moe_pipelines);
- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
  // use n_experts from push constant if it's not equal to the power of two spec constant
  bool use_push = dst->ne[0] != (1u << idx);
- return ctx->device->pipeline_topk_moe[idx][
+ return ctx->device->pipeline_topk_moe[idx][use_push];
  }

  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8654,6 +8783,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_rope_multi_f32;
  }
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_multi_f32_f16;
+ }
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
  return ctx->device->pipeline_rope_multi_f16;
  }
@@ -8686,7 +8818,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return nullptr;
  case GGML_OP_CUMSUM:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-
+ if (src0->ne[0] <= 512) {
+ return ctx->device->pipeline_cumsum_small_f32;
+ } else {
+ return ctx->device->pipeline_cumsum_f32;
+ }
  }
  return nullptr;
  case GGML_OP_SOLVE_TRI:
@@ -9057,10 +9193,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
  } break;
  case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
  break;
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ {
+ uint32_t nrows = (uint32_t)ggml_nrows(src0);
+ uint32_t z = 1;
+ if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+ z = CEIL_DIV(nrows, 32768);
+ nrows = 32768;
+ }
+ elements = { nrows, (uint32_t)ne00, z };
+
+ } break;
  case GGML_OP_GET_ROWS:
  elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
  elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
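Note: a worked example of the ROPE workgroup split added above, with an assumed row count and device limit (both numbers are illustrative, not from this diff):

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint32_t nrows = 100000;          // ggml_nrows(src0), assumed
        const uint32_t max_wg0 = 65535;   // maxComputeWorkGroupCount[0], assumed
        uint32_t z = 1;
        if (nrows > max_wg0) {
            z = (nrows + 32768 - 1) / 32768;  // CEIL_DIV -> 4 slices along z
            nrows = 32768;                    // rows covered per slice along x
        }
        printf("elements = { %u, ne00, %u }\n", nrows, z);
        return 0;
    }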
@@ -9084,6 +9230,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  const uint32_t batch = src1->ne[is_2D ? 3 : 2];

  elements = { OW * KW * KH, OH, batch * IC };
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
  } break;
  case GGML_OP_IM2COL_3D:
  {
@@ -9695,14 +9843,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su

  ggml_vk_op_f32_opt_step_adamw(
  ctx, subctx, dst,
- { (uint32_t)n, 0, 0.0f, 0.0f }
+ { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
  );
  }

  static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
  const size_t n = ggml_nelements(dst->src[0]);

- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9788,6 +9936,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
  1,
  ggml_get_op_params_f32(dst, 0),
  ggml_get_op_params_f32(dst, 2),
+ 0.0f, 0.0f,
  };

  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9809,6 +9958,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
  1,
  ggml_get_op_params_f32(dst, 0),
  0.0f,
+ 0.0f, 0.0f,
  };

  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9924,13 +10074,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
  }

  static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
  float * op_params = (float *)dst->op_params;

- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9941,7 +10091,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
  const float eps = float_op_params[1];
  const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);

- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
  }

  static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -9984,7 +10134,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
  uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);

  vk_op_rope_push_constants rope {
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+ (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
  has_ff, (uint32_t)src0->ne[2], nb01, nb02,
  { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
@@ -10110,16 +10260,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,

  static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
  float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
+ }
+
+ static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
+ {
+ (uint32_t)ggml_nelements(src0), 0,
+ op_params[1], op_params[2], op_params[3], op_params[4]
+ }
+ );
  }

  static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10244,18 +10404,20 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,

  static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  float * op_params = (float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
  }

  static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
- topk_moe_mode mode =
+ topk_moe_mode mode = ctx->fused_topk_moe_mode;
  ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor *
-
-
-
+ ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
+ ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+ ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
+ (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
+ cgraph->nodes[node_idx + 3];

  GGML_ASSERT(logits->type == GGML_TYPE_F32);
+ GGML_ASSERT(bias->type == GGML_TYPE_F32);
  GGML_ASSERT(weights->type == GGML_TYPE_F32);
  GGML_ASSERT(ids->type == GGML_TYPE_I32);

@@ -10270,6 +10432,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
  ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

  vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
+ vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
  vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
  vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);

@@ -10277,18 +10440,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
  pc.n_rows = n_rows;
  pc.n_experts_push = n_experts;
  pc.n_expert_used = n_expert_used;
+ pc.clamp_min = -std::numeric_limits<float>::infinity();
+ pc.clamp_max = std::numeric_limits<float>::infinity();
  if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
  ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
+ if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
  pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
  pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
  }

+ #define GATING_FUNC_SOFTMAX 0
+ #define GATING_FUNC_SIGMOID 1
+ #define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+ pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
+ mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
+ GATING_FUNC_SOFTMAX;
+ pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+ pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+ if (ctx->fused_topk_moe_scale) {
+ GGML_ASSERT(weights->op == GGML_OP_SCALE);
+ pc.output_scale = ggml_get_op_params_f32(weights, 0);
+ pc.output_bias = ggml_get_op_params_f32(weights, 1);
+ } else {
+ pc.output_scale = 1.0f;
+ pc.output_bias = 0.0f;
+ }
+
  GGML_ASSERT(n_expert_used <= n_experts);

  const uint32_t rows_per_block = 4;
  std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };

- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
  }

  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@@ -10536,16 +10726,58 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
  }

  static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
- vk_op_sum_rows_push_constants
-
+ vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+ // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
+ // For fewer, larger rows, use the multipass shader to spread each row across SMs.
+ if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+ return;
+ }
+
+ // First pass computes partial sums within a block, and stores the last partial
+ // to the temp buffer. Second pass sums the block partials from the temp buffer
+ // and adds that to the result of the first pass.
+ vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+ vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+ GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
+
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+
+ std::array<uint32_t, 3> elements;
+
+ elements[0] = dst->ne[0];
+ elements[1] = (uint32_t)ggml_nrows(dst);
+ elements[2] = 1;
+
+ size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+ if (ctx->prealloc_size_split_k < temp_size) {
+ ctx->prealloc_size_split_k = temp_size;
+ ggml_vk_preallocate_buffers(ctx, subctx);
+ }
+
+ vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+ vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+ ggml_vk_sync_buffers(ctx, subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+ ctx->prealloc_split_k_need_sync = true;
  }

  static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
  }

  static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
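Note: a CPU reference for the two-pass cumulative sum used by the multipass path above; the function name and block size are illustrative, not taken from the shaders:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Pass 1 produces per-block inclusive scans plus each block's total;
    // pass 2 adds the running total of all preceding blocks to every element of a block.
    std::vector<float> cumsum_two_pass(const std::vector<float> & x, size_t block) {
        std::vector<float> out(x.size());
        std::vector<float> block_total((x.size() + block - 1) / block, 0.0f);
        for (size_t b = 0; b * block < x.size(); ++b) {            // pass 1: local scans
            float acc = 0.0f;
            for (size_t i = b * block; i < std::min(x.size(), (b + 1) * block); ++i) {
                acc += x[i];
                out[i] = acc;
            }
            block_total[b] = acc;
        }
        float carry = 0.0f;
        for (size_t b = 0; b * block < x.size(); ++b) {            // pass 2: add block offsets
            for (size_t i = b * block; i < std::min(x.size(), (b + 1) * block); ++i) {
                out[i] += carry;
            }
            carry += block_total[b];
        }
        return out;
    }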
@@ -10587,6 +10819,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
  const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32

  const uint32_t pelements = OW * KW * KH;
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];

  const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
  const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10599,7 +10832,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
  IC, IW, IH, OW, OH, KW, KH,
  pelements,
  IC * KH * KW,
- s0, s1, p0, p1, d0, d1,
+ s0, s1, p0, p1, d0, d1, batch * IC
  });
  }

@@ -10804,7 +11037,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx

  static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
  const float * op_params = (const float *)dst->op_params;
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
  }

  #ifdef GGML_VULKAN_RUN_TESTS
@@ -12029,6 +12262,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

  break;
  case GGML_OP_UNARY:
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+ break;
+ }
+
  switch (ggml_get_unary_op(node)) {
  case GGML_UNARY_OP_EXP:
  case GGML_UNARY_OP_SILU:
@@ -12050,6 +12288,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
  case GGML_UNARY_OP_TRUNC:
  ggml_vk_unary(ctx, compute_ctx, src0, node);
  break;
+ case GGML_UNARY_OP_XIELU:
+ ggml_vk_xielu(ctx, compute_ctx, src0, node);
+ break;
  default:
  return false;
  }
@@ -12073,7 +12314,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

  break;
  case GGML_OP_SOFT_MAX:
- if (ctx->
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
  ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
  } else {
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12093,7 +12334,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

  break;
  case GGML_OP_ARGSORT:
- if (ctx->
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
  ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
  } else {
  ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -12643,7 +12884,23 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
 
     vk_buffer buf = buf_ctx->dev_buffer;
 
-
+    auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
+
+    bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
+
+    if (!ret) {
+        ggml_vk_ensure_sync_staging_buffer(ctx, size);
+        ggml_vk_sync_buffers(nullptr, transfer_ctx);
+
+        vk::BufferCopy buffer_cpy;
+        buffer_cpy.srcOffset = 0;
+        buffer_cpy.dstOffset = dst_offset;
+        buffer_cpy.size = size;
+
+        transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
+        deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
+        ggml_vk_synchronize(ctx);
+    }
 }
 
 static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
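This hunk gives `ggml_backend_vk_set_tensor_async` a real body: it attempts an asynchronous device write and falls back to a staging-buffer copy plus synchronization if that fails. From application code the path is reached through the generic ggml-backend API; a minimal usage sketch, assuming `backend` is a Vulkan `ggml_backend_t` and `tensor` already lives in a device buffer:

```cpp
#include "ggml-backend.h"

static void upload_weights(ggml_backend_t backend, ggml_tensor * tensor,
                           const void * host_data, size_t nbytes) {
    // Queues the copy; may return before the data has reached the GPU.
    ggml_backend_tensor_set_async(backend, tensor, host_data, 0, nbytes);

    // Block until all queued async work (including the copy) has completed.
    ggml_backend_synchronize(backend);
}
```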
@@ -12920,42 +13177,81 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
 
     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;
 
     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
+        break;
+    case TOPK_MOE_SIGMOID_NORM_BIAS:
+        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
+        weights = cgraph->nodes[node_idx + 10];
+        get_rows = cgraph->nodes[node_idx + 5];
+        argsort = cgraph->nodes[node_idx + 3];
+        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        // bias is expected to be 1D
+        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+            return false;
+        }
+        // sigmoid fusion seems to generate infinities on moltenvk
+        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+            return false;
+        }
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }
 
-
-
-    float scale = op_params[0];
-    float max_bias = op_params[1];
-
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
         return false;
     }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
 
-    if (
+    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
         return false;
     }
 
-
-    if (softmax->src[1] || softmax->src[2]) {
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
         return false;
     }
 
+    if (softmax->op == GGML_OP_SOFT_MAX) {
+        const float * op_params = (const float *)softmax->op_params;
+
+        float scale = op_params[0];
+        float max_bias = op_params[1];
+
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+    }
+
     const int n_expert = softmax->ne[0];
     if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
         return false;
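The rewritten `ggml_vk_can_fuse_topk_moe` only inspects the softmax scale/bias parameters when the gating node really is a `GGML_OP_SOFT_MAX` (the new sigmoid variant skips that block). A condensed restatement of that eligibility check as a standalone helper, offered as a sketch rather than the actual helper used in ggml-vulkan.cpp:

```cpp
#include "ggml.h"

// A gating SOFT_MAX node only qualifies for the fused top-k MoE path when it
// is a plain softmax: unit scale, no ALiBi bias, and no mask/sink inputs.
static bool is_plain_softmax(const ggml_tensor * softmax) {
    const float * p = (const float *)softmax->op_params;
    const float scale    = p[0];
    const float max_bias = p[1];

    return scale == 1.0f && max_bias == 0.0f &&
           softmax->src[1] == nullptr && softmax->src[2] == nullptr;
}
```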
@@ -12997,9 +13293,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
         return false;
     }
 
-    // Only norm/neox shaders have the fusion code
+    // Only norm/neox/mrope shaders have the fusion code
     const int mode = ((const int32_t *) rope->op_params)[2];
-    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
         return false;
     }
 
@@ -13226,6 +13522,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+        ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
         if (!ctx->device->disable_fusion) {
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@@ -13271,13 +13569,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
+                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 4;
+                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
+                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@@ -13285,8 +13593,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
             }
+            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
+                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
+                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
+                    ctx->fused_topk_moe_scale = true;
+                    ctx->num_additional_fused_ops++;
+                }
+            }
         }
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
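`ggml_backend_vk_graph_compute` now records which top-k MoE pattern (if any) was matched in `ctx->fused_topk_moe_mode`, with `TOPK_MOE_COUNT` serving as the "nothing fused" sentinel, plus a flag for an optional trailing scale op. A sketch of the bookkeeping this implies; the real enum and context fields live elsewhere in ggml-vulkan.cpp and may be ordered or named differently:

```cpp
// Sketch inferred from the hunks above, not the actual declarations.
enum topk_moe_mode_sketch {
    TOPK_MOE_EARLY_SOFTMAX_NORM,
    TOPK_MOE_SIGMOID_NORM_BIAS,   // new pattern in this release
    TOPK_MOE_EARLY_SOFTMAX,
    TOPK_MOE_LATE_SOFTMAX,
    TOPK_MOE_COUNT,               // doubles as the "no fusion matched" sentinel
};

struct fused_topk_moe_state_sketch {
    topk_moe_mode_sketch mode  = TOPK_MOE_COUNT;
    bool                 scale = false;  // extra trailing GGML_OP_SCALE folded into the fusion
};
```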
@@ -13465,6 +13782,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
         if (keep_pattern(topk_moe_early_softmax_norm)) {
             continue;
         }
+        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
+            continue;
+        }
         if (keep_pattern(topk_moe_early_softmax)) {
             continue;
         }
@@ -13491,6 +13811,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             // Don't pull forward nodes from fusion patterns
             if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
                 match_pattern(topk_moe_early_softmax, j) ||
                 match_pattern(topk_moe_late_softmax, j)) {
                 continue;
@@ -13502,7 +13823,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
-                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
                 ok = false;
                 break;
             }
@@ -13630,11 +13952,62 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     }
 }
 
+static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+    vk_event *vkev = (vk_event *)event->context;
+
+    vk_context transfer_ctx;
+
+    if (ctx->transfer_ctx.expired()) {
+        // Initialize new transfer context
+        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+        ctx->transfer_ctx = transfer_ctx;
+        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    } else {
+        transfer_ctx = ctx->transfer_ctx.lock();
+    }
+
+    // the backend interface doesn't have an explicit reset, so reset it here
+    // before we record the command to set it
+    ctx->device->device.resetEvent(vkev->event);
+    ctx->device->device.resetFences({ vkev->fence });
+
+    ggml_vk_set_event(transfer_ctx, vkev->event);
+
+    ggml_vk_ctx_end(transfer_ctx);
+
+    ggml_vk_submit(transfer_ctx, {vkev->fence});
+    ctx->submit_pending = true;
+    ctx->transfer_ctx.reset();
+}
+
+static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+    vk_event *vkev = (vk_event *)event->context;
+
+    vk_context transfer_ctx;
+
+    if (ctx->transfer_ctx.expired()) {
+        // Initialize new transfer context
+        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+        ctx->transfer_ctx = transfer_ctx;
+        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    } else {
+        transfer_ctx = ctx->transfer_ctx.lock();
+    }
+
+    ggml_vk_wait_events(transfer_ctx, {vkev->event});
+    ggml_vk_ctx_end(transfer_ctx);
+    ctx->transfer_ctx.reset();
+}
+
 // TODO: enable async and synchronize
 static ggml_backend_i ggml_backend_vk_interface = {
     /* .get_name = */ ggml_backend_vk_name,
     /* .free = */ ggml_backend_vk_free,
-    /* .set_tensor_async = */
+    /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
     /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
     /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
     /* .synchronize = */ ggml_backend_vk_synchronize,
@@ -13643,8 +14016,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .graph_plan_update = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute = */ ggml_backend_vk_graph_compute,
-    /* .event_record = */
-    /* .event_wait = */
+    /* .event_record = */ ggml_backend_vk_event_record,
+    /* .event_wait = */ ggml_backend_vk_event_wait,
     /* .graph_optimize = */ ggml_vk_graph_optimize,
 };
 
@@ -13819,10 +14192,10 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
     props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
     ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* .async = */
+        /* .async = */ true,
         /* .host_buffer = */ true,
         /* .buffer_from_host_ptr = */ false,
-        /* .events = */
+        /* .events = */ true,
     };
 }
 
@@ -13842,6 +14215,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_XIELU:
         case GGML_UNARY_OP_NEG:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
@@ -14191,7 +14565,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             return true;
         case GGML_OP_UPSCALE:
-
+            if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
+                if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
+                    return false;
+                }
+            }
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_CONCAT:
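With this change `supports_op` accepts `GGML_OP_UPSCALE` with the antialias flag only when the scale mode (low byte of `op_params[0]`) is bilinear. The same check restated as a standalone helper, as a sketch only; the backend performs it inline as shown above:

```cpp
#include "ggml.h"

static bool vk_supports_upscale_sketch(const ggml_tensor * op) {
    const int32_t params    = op->op_params[0];
    const bool    antialias = (params & GGML_SCALE_FLAG_ANTIALIAS) != 0;
    const int32_t mode      = params & 0xFF;   // scale mode lives in the low byte

    if (antialias && mode != GGML_SCALE_MODE_BILINEAR) {
        return false;   // antialiased upscale is only wired up for bilinear
    }
    return op->src[0]->type == GGML_TYPE_F32;
}
```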
@@ -14353,6 +14732,47 @@ static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml
     UNUSED(dev);
 }
 
+static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+
+    vk_event *vkev = new vk_event;
+    if (!vkev) {
+        return nullptr;
+    }
+
+    // The event/fence is expected to initially be in the signaled state.
+    vkev->event = device->device.createEvent({});
+    vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
+    device->device.setEvent(vkev->event);
+
+    return new ggml_backend_event {
+        /* .device = */ dev,
+        /* .context = */ vkev,
+    };
+}
+
+static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+
+    vk_event *vkev = (vk_event *)event->context;
+
+    device->device.destroyFence(vkev->fence);
+    device->device.destroyEvent(vkev->event);
+    delete vkev;
+    delete event;
+}
+
+static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+    vk_event *vkev = (vk_event *)event->context;
+
+    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
+}
+
 static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
     /* .get_name = */ ggml_backend_vk_device_get_name,
     /* .get_description = */ ggml_backend_vk_device_get_description,
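With `event_new`/`event_free`/`event_synchronize` added to the device vtable here, and `event_record`/`event_wait` added to the backend vtable earlier, the Vulkan backend can take part in the generic ggml-backend event mechanism. A minimal usage sketch, assuming `dev` is the Vulkan `ggml_backend_dev_t` and `backend` a backend created from it:

```cpp
#include "ggml-backend.h"

static void sync_point_example(ggml_backend_dev_t dev, ggml_backend_t backend) {
    ggml_backend_event_t ev = ggml_backend_event_new(dev);   // -> event_new hook

    // ... submit some work on `backend` ...

    ggml_backend_event_record(ev, backend);   // -> event_record hook
    ggml_backend_event_synchronize(ev);       // -> event_synchronize hook (waits on the fence)

    ggml_backend_event_free(ev);              // -> event_free hook
}
```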
@@ -14366,9 +14786,9 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
     /* .supports_op = */ ggml_backend_vk_device_supports_op,
     /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
     /* .offload_op = */ ggml_backend_vk_device_offload_op,
-    /* .event_new = */
-    /* .event_free = */
-    /* .event_synchronize = */
+    /* .event_new = */ ggml_backend_vk_device_event_new,
+    /* .event_free = */ ggml_backend_vk_device_event_free,
+    /* .event_synchronize = */ ggml_backend_vk_device_event_synchronize,
 };
 
 static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
@@ -14747,7 +15167,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_LOG) {
         tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_TRI) {
-        tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+        tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
     } else if (tensor->op == GGML_OP_DIAG) {
         tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
@@ -14835,6 +15255,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_RELU:
             tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_XIELU:
+            tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
+            ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
+            ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
+            ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
+            ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
+            break;
         case GGML_UNARY_OP_NEG:
             tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
             break;