@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
--- a/package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -102,7 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 struct ggml_backend_vk_context;
 
-#define MAX_PARAMETER_COUNT
+#define MAX_PARAMETER_COUNT 12
+// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
 
 struct vk_pipeline_struct {
     std::string name;
@@ -113,6 +115,8 @@ struct vk_pipeline_struct {
     uint32_t parameter_count;
     std::array<uint32_t, 3> wg_denoms;
     uint32_t align;
+    // true if fields have been set by ggml_vk_create_pipeline
+    bool initialized {};
     // set to true to request the pipeline is compiled after the dryrun
     bool needed {};
     // set to true when the shader has been compiled
@@ -222,21 +226,7 @@ enum vk_device_architecture {
     AMD_RDNA2,
     AMD_RDNA3,
     INTEL_XE2,
-
-
-// HSK x HSV
-enum FaHeadSizes {
-    FA_HEAD_SIZE_64,
-    FA_HEAD_SIZE_80,
-    FA_HEAD_SIZE_96,
-    FA_HEAD_SIZE_112,
-    FA_HEAD_SIZE_128,
-    FA_HEAD_SIZE_192,
-    FA_HEAD_SIZE_192_128,
-    FA_HEAD_SIZE_256,
-    FA_HEAD_SIZE_576_512,
-    FA_HEAD_SIZE_UNSUPPORTED,
-    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+    NVIDIA_PRE_TURING,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -315,10 +305,64 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
            return vk_device_architecture::INTEL_XE2;
         }
+    } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool cooperative_matrix = false;
+
+        // Detect "pre-turing" based on lack of coopmat support.
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
+                cooperative_matrix = true;
+                break;
+            }
+        }
+
+        if (!cooperative_matrix) {
+            return vk_device_architecture::NVIDIA_PRE_TURING;
+        }
     }
     return vk_device_architecture::OTHER;
 }
 
+enum vk_conv_shapes {
+    CONV_SHAPE_128x128,
+    CONV_SHAPE_64x32,
+    CONV_SHAPE_32x256,
+    CONV_SHAPE_COUNT,
+};
+
+enum dmmv_wg_sizes {
+    DMMV_WG_SIZE_SUBGROUP,
+    DMMV_WG_SIZE_LARGE,
+    DMMV_WG_SIZE_COUNT,
+};
+
+enum FaCodePath {
+    FA_SCALAR,
+    FA_COOPMAT1,
+    FA_COOPMAT2,
+};
+
+struct vk_fa_pipeline_state {
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
+
+    uint32_t HSK, HSV;
+    bool small_rows;
+    FaCodePath path;
+    bool aligned;
+    bool f32acc;
+
+    bool operator<(const vk_fa_pipeline_state &b) const {
+        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+    }
+};
+
+static constexpr uint32_t num_argsort_pipelines = 11;
+static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
@@ -344,6 +388,11 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool subgroup_ballot;
+    bool multi_add;
+
+    bool add_rms_fusion;
+    uint32_t partials_binding_alignment;
 
     bool integer_dot_product;
 
@@ -405,8 +454,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_quantize_q8_1;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
@@ -424,11 +473,20 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
+
+    // indexed by num_additional_fused_ops == num_adds - 1
+    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
+
+    vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sqrt_f32;
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
@@ -444,10 +502,13 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_mul_f32;
+    vk_pipeline pipeline_rms_norm_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
     // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_exp[2];
     vk_pipeline pipeline_gelu[2];
     vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
@@ -459,6 +520,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_swiglu_oai[2];
    vk_pipeline pipeline_geglu_erf[2];
    vk_pipeline pipeline_geglu_quick[2];
 
@@ -472,7 +534,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
-    vk_pipeline pipeline_argsort_f32;
+    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
@@ -483,20 +545,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
-    vk_pipeline
-    vk_pipeline
-    vk_pipeline
-
-
-    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+    vk_pipeline pipeline_opt_step_sgd_f32;
+    vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
 
-    vk_pipeline
-
-    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+    std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
-    std::
+    std::vector<vk_pipeline_ref> all_pipelines;
 
     std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
 
@@ -506,6 +565,7 @@ struct vk_device_struct {
     ggml_backend_buffer_type buffer_type;
 
     bool disable_fusion;
+    bool disable_host_visible_vidmem;
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
@@ -526,15 +586,15 @@ struct vk_device_struct {
         compute_queue.cmd_pool.destroy(device);
         transfer_queue.cmd_pool.destroy(device);
 
-        for (auto& pipeline :
-            if (pipeline.
+        for (auto& pipeline : all_pipelines) {
+            if (pipeline.expired()) {
                 continue;
             }
 
-            vk_pipeline pl = pipeline.
+            vk_pipeline pl = pipeline.lock();
             ggml_vk_destroy_pipeline(device, pl);
         }
-
+        all_pipelines.clear();
 
         device.destroyDescriptorSetLayout(dsl);
 
@@ -680,6 +740,8 @@ struct vk_op_glu_push_constants {
     uint32_t ne00;
     uint32_t ne20;
     uint32_t mode; // 0: default, 1: swapped, 2: split
+    float alpha; // for swiglu_oai
+    float limit;
 };
 
 struct vk_op_unary_push_constants {
@@ -769,6 +831,28 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_multi_add_push_constants {
+    // shape for dst
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
+
+    // strides for srcs+dst
+    uint32_t nb[MAX_PARAMETER_COUNT][4];
+
+    uint32_t rms_partials;
+};
+// update multi_add.comp if this changes
+static_assert(MAX_PARAMETER_COUNT == 12);
+static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
+
+struct vk_op_add_id_push_constants {
+    uint32_t ne0;
+    uint32_t ne1;
+    uint32_t s01;
+    uint32_t s02;
+    uint32_t s11;
+    uint32_t s21;
+};
+
 struct vk_op_diag_mask_push_constants {
     uint32_t ncols;
     uint32_t rows_per_channel;
@@ -810,11 +894,11 @@ struct vk_op_soft_max_push_constants {
     float m1;
     uint32_t n_head_log2;
     uint32_t nrows_x;
+    uint32_t has_sinks;
 };
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
-    uint32_t ncols_pad;
     int32_t order;
 };
 
@@ -907,8 +991,22 @@ struct vk_op_conv2d_push_constants {
     uint32_t nb1;
     uint32_t nb2;
     uint32_t nb3;
+
+    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
+    uint32_t KWmp;   uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp;   uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
 };
 
+template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
+    // Compute magic values to divide by KW, KW*KH, OW, OW*OH
+    init_fastdiv_values(p.KW,      p.KWmp,   p.KWL);
+    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
+    init_fastdiv_values(p.OW,      p.OWmp,   p.OWL);
+    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
+}
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
@@ -935,6 +1033,39 @@ struct vk_op_upscale_push_constants {
     float sf0; float sf1; float sf2; float sf3;
 };
 
+struct vk_op_sum_rows_push_constants
+{
+    uint32_t n_cols;
+    uint32_t ne01, ne02;
+    uint32_t nb01, nb02, nb03;
+    uint32_t nb11, nb12, nb13;
+    float weight;
+    uint32_t misalign_offsets;
+    uint32_t ne0_12mp, ne0_12L;
+    uint32_t ne0_1mp, ne0_1L;
+};
+
+static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
+    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
+    vk_op_sum_rows_push_constants p = {};
+    p.n_cols = (uint32_t)n_cols;
+    p.ne01 = (uint32_t)src->ne[1];
+    p.ne02 = (uint32_t)src->ne[2];
+    p.nb01 = (uint32_t)src->nb[1] / type_size;
+    p.nb02 = (uint32_t)src->nb[2] / type_size;
+    p.nb03 = (uint32_t)src->nb[3] / type_size;
+    p.nb11 = (uint32_t)dst->nb[1] / type_size;
+    p.nb12 = (uint32_t)dst->nb[2] / type_size;
+    p.nb13 = (uint32_t)dst->nb[3] / type_size;
+    p.weight = 1.0f;
+    return p;
+}
+
+template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
+    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
+    init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);
+}
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1055,17 +1186,23 @@ class vk_perf_logger {
             return;
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
-            const uint64_t m
-            const uint64_t n
-            const uint64_t k
-
-
-
-
-            name += "
+            const uint64_t m = node->src[0]->ne[1];
+            const uint64_t n = node->ne[1];
+            const uint64_t k = node->src[1]->ne[0];
+            const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
+            std::string name = ggml_op_name(node->op);
+            if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
+                (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
+                name += "_VEC";
+            }
+            name += " ";
+            name += ggml_type_name(node->src[0]->type);
+            name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
+            if (batch > 1) {
+                name += " batch=" + std::to_string(batch);
             }
             timings[name].push_back(time);
-            flops[name].push_back(m * n * (k + (k - 1)));
+            flops[name].push_back(m * n * (k + (k - 1)) * batch);
             return;
         }
         if (node->op == GGML_OP_CONV_2D) {
@@ -1089,6 +1226,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1103,10 +1246,25 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // write partial sums to accumulate the square of the vector components
+    bool do_add_rms_partials;
+
+    // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
+    vk_pipeline_struct * prealloc_y_last_pipeline_used {};
+    const ggml_tensor * prealloc_y_last_tensor_used {};
+
+    // Track which nodes have been used since the last sync, and whether they were written to
+    std::vector<const ggml_tensor *> unsynced_nodes_written;
+    std::vector<const ggml_tensor *> unsynced_nodes_read;
+    // Track which prealloc buffers have pending reads that need to be synchronized.
+    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
+    // and set to true after the buffer contents are consumed.
+    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -1340,13 +1498,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         vk::DebugUtilsObjectNameInfoEXT duoni;
         duoni.objectType = vk::ObjectType::ePipeline;
         duoni.pObjectName = pipeline->name.c_str();
-        duoni.objectHandle = reinterpret_cast
+        duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
         vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
     }
 
     {
         std::lock_guard<std::recursive_mutex> guard(device->mutex);
-        device->
+        device->all_pipelines.push_back(pipeline);
     }
 
     {
@@ -1750,6 +1908,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
     } else if (device->uma) {
         // Fall back to host memory type
         buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+    } else if (device->disable_host_visible_vidmem) {
+        buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
     } else {
         // use rebar if available, otherwise fallback to device only visible memory
         buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -1781,14 +1941,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
     return { buf, 0, VK_WHOLE_SIZE };
 }
 
-static void ggml_vk_sync_buffers(vk_context&
+static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
     VK_LOG_DEBUG("ggml_vk_sync_buffers()");
 
-    const bool transfer_queue =
+    const bool transfer_queue = subctx->p->q->transfer_only;
 
-    ctx
-    ctx->
-
+    if (ctx) {
+        ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
+    }
+
+    subctx->s->buffer.pipelineBarrier(
+        subctx->p->q->stage_flags,
+        subctx->p->q->stage_flags,
         {},
         { {
           { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
@@ -1815,47 +1979,12 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
-enum FaCodePath {
-    FA_SCALAR,
-    FA_COOPMAT1,
-    FA_COOPMAT2,
-};
-
-static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
-    if (hsk != 192 && hsk != 576 && hsk != hsv) {
-        return FA_HEAD_SIZE_UNSUPPORTED;
-    }
-    switch (hsk) {
-        case 64: return FA_HEAD_SIZE_64;
-        case 80: return FA_HEAD_SIZE_80;
-        case 96: return FA_HEAD_SIZE_96;
-        case 112: return FA_HEAD_SIZE_112;
-        case 128: return FA_HEAD_SIZE_128;
-        case 192:
-            if (hsv == 192) {
-                return FA_HEAD_SIZE_192;
-            } else if (hsv == 128) {
-                return FA_HEAD_SIZE_192_128;
-            } else {
-                return FA_HEAD_SIZE_UNSUPPORTED;
-            }
-        case 256: return FA_HEAD_SIZE_256;
-        case 576:
-            if (hsv == 512) {
-                return FA_HEAD_SIZE_576_512;
-            } else {
-                return FA_HEAD_SIZE_UNSUPPORTED;
-            }
-        default: return FA_HEAD_SIZE_UNSUPPORTED;
-    }
-}
-
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
 
 static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
-    if (hsv >=
+    if (hsv >= 192) {
         return 2;
     } else {
         return 8;
@@ -1885,7 +2014,13 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
         if (small_rows) {
             return {scalar_flash_attention_num_small_rows, 64};
         } else {
-
+            if ((hsv | hsk) & 8) {
+                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
+                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
+                return {get_fa_scalar_num_large_rows(hsv), 64};
+            } else {
+                return {get_fa_scalar_num_large_rows(hsv), 32};
+            }
         }
     }
 
@@ -1903,8 +2038,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     }
 
     // small cols to reduce register count
-    if (ggml_is_quantized(type) || hsk >= 256) {
-        if (hsk >= 512) {
+    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
+        if (hsk >= 512 || hsv >= 512) {
             return {32, 32};
         } else {
             return {64, 32};
@@ -1913,6 +2048,10 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     return {64, 64};
 }
 
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
+    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+}
+
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
 
     uint32_t lut_size = 0;
@@ -1938,6 +2077,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
             break;
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_MXFP4:
             lut_size = 4*16;
             break;
         default:
@@ -1950,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? (
+    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+    const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
 
-    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -2037,8 +2178,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2037
2178
|
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
|
|
2038
2179
|
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
|
|
2039
2180
|
|
|
2181
|
+
const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
|
|
2182
|
+
const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
|
|
2183
|
+
const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
|
|
2184
|
+
const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
|
|
2185
|
+
|
|
2186
|
+
const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
|
|
2187
|
+
(device->subgroup_size_control && device->subgroup_max_size >= 16);
|
|
2188
|
+
|
|
2040
2189
|
// mulmat
|
|
2041
2190
|
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
|
2191
|
+
l_warptile_id, m_warptile_id, s_warptile_id,
|
|
2042
2192
|
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
|
2043
2193
|
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
|
2044
2194
|
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
|
@@ -2067,17 +2217,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2067
2217
|
s_mmq_wg_denoms = { 32, 64, 1 };
|
|
2068
2218
|
|
|
2069
2219
|
// spec constants and tile sizes for quant matmul (Qi_K)
|
|
2070
|
-
l_warptile_mmq_k = { 256,
|
|
2071
|
-
m_warptile_mmq_k = { 256,
|
|
2072
|
-
s_warptile_mmq_k = { 256, 32,
|
|
2073
|
-
l_mmq_wg_denoms_k = {
|
|
2074
|
-
m_mmq_wg_denoms_k = {
|
|
2075
|
-
s_mmq_wg_denoms_k = { 32,
|
|
2220
|
+
l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
|
|
2221
|
+
m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
|
|
2222
|
+
s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
|
|
2223
|
+
l_mmq_wg_denoms_k = { 128, 256, 1 };
|
|
2224
|
+
m_mmq_wg_denoms_k = { 128, 128, 1 };
|
|
2225
|
+
s_mmq_wg_denoms_k = { 32, 64, 1 };
|
|
2076
2226
|
|
|
2077
2227
|
// spec constants and tile sizes for quant matmul_id
|
|
2078
|
-
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
|
2079
|
-
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
2080
|
-
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
2228
|
+
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
|
|
2229
|
+
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
|
|
2230
|
+
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
|
|
2081
2231
|
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
2082
2232
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
2083
2233
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
@@ -2109,9 +2259,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2109
2259
|
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
2110
2260
|
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
|
2111
2261
|
|
|
2262
|
+
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
|
2263
|
+
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
|
2264
|
+
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
|
|
2265
|
+
|
|
2266
|
+
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
|
|
2267
|
+
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
|
2268
|
+
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
|
2269
|
+
|
|
2112
2270
|
// chip specific tuning
|
|
2113
2271
|
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
|
2114
2272
|
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
2273
|
+
m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
2115
2274
|
}
|
|
2116
2275
|
|
|
2117
2276
|
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
|
@@ -2137,14 +2296,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2137
2296
|
}
|
|
2138
2297
|
|
|
2139
2298
|
// Disable mul_mat_id if not enough shared memory is available
|
|
2140
|
-
if (!ggml_vk_matmul_shmem_support(device,
|
|
2299
|
+
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
|
|
2141
2300
|
device->mul_mat_id_s[i] = false;
|
|
2142
2301
|
device->mul_mat_id_m[i] = false;
|
|
2143
2302
|
device->mul_mat_id_l[i] = false;
|
|
2144
|
-
} else if (!ggml_vk_matmul_shmem_support(device,
|
|
2303
|
+
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
|
|
2145
2304
|
device->mul_mat_id_m[i] = false;
|
|
2146
2305
|
device->mul_mat_id_l[i] = false;
|
|
2147
|
-
} else if (!ggml_vk_matmul_shmem_support(device,
|
|
2306
|
+
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
|
|
2148
2307
|
device->mul_mat_id_l[i] = false;
|
|
2149
2308
|
}
|
|
2150
2309
|
}
|
|
@@ -2177,11 +2336,14 @@ static void ggml_vk_load_shaders(vk_device& device) {

 if (!pipeline) {
 pipeline = std::make_shared<vk_pipeline_struct>();
+}
+if (!pipeline->initialized) {
 pipeline->name = name;
 pipeline->parameter_count = parameter_count;
 pipeline->push_constant_size = push_constant_size;
 pipeline->wg_denoms = wg_denoms;
 pipeline->align = align;
+pipeline->initialized = true;
 }

 if (!pipeline->needed || pipeline->compiled) {
@@ -2227,26 +2389,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
 return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
 };

-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
- [9 removed lines not rendered in this diff view]
+for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
+uint32_t HSK = fa.first.HSK; \
+uint32_t HSV = fa.first.HSV; \
+bool small_rows = fa.first.small_rows; \
+FaCodePath path = fa.first.path; \
+bool aligned = fa.first.aligned; \
+bool f32acc = fa.first.f32acc; \
+if (path == FAPATH) { \
+if (aligned) { \
+if (f32acc) { \
+ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+} else { \
+ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+} \
+} else { \
+if (f32acc) { \
+ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+} else { \
+ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+} \
+} \
+} \
+}

 CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2269,7 +2435,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
 }
 #endif
-#undef CREATE_FA2
 #undef CREATE_FA

 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -2314,32 +2479,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+
+GGML_ASSERT(device->subgroup_ballot);

-CREATE_MM2(pipeline_matmul_id_f16,
+CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 if (device->coopmat_bf16_support) {
-CREATE_MM(pipeline_matmul_id_bf16,
+CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 }
 #endif
- [19 removed lines not rendered in this diff view]
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
 } else
@@ -2401,6 +2570,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 } else {
 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2422,79 +2592,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 }

- [3 removed lines not rendered in this diff view]
+GGML_ASSERT(device->subgroup_ballot);
+
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 if (device->coopmat_bf16_support) {
-CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16,
+CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 }
 #endif

- [20 removed lines not rendered in this diff view]
-CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-} else {
-CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
-CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-}
+CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
 #undef CREATE_MM
 } else
 #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 if (device->fp16) {
 // Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
 if (device->mul_mat ## ID ## _l[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _l[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \

 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
 if (device->mul_mat ## ID ## _l[TYPE]) { \
@@ -2511,37 +2661,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
 } \

 // Create 2 variants, {f16,f32} accumulator
-#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-
-CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 if (device->integer_dot_product) {
@@ -2553,50 +2704,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
 }
 #endif

- [26 removed lines not rendered in this diff view]
+if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+} else {
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+}
 #undef CREATE_MM2
 #undef CREATE_MMQ
 #undef CREATE_MM
 } else {
 // Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
 if (device->mul_mat ## ID ## _l[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _l[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _m[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 if (device->mul_mat ## ID ## _s[TYPE]) \
-ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \

 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
 if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2606,33 +2784,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
 if (device->mul_mat ## ID ## _s[TYPE]) \
 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \

-CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
-CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 if (device->integer_dot_product) {
@@ -2644,32 +2823,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
  }
  #endif

- [removed lines 2647-2672 are not rendered in this diff view]
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
  }
  // reusing CREATE_MM from the fp32 path
  if ((device->coopmat2 || device->coopmat_support)
@@ -2686,8 +2892,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  m_wg_denoms = { 64, 64, 1 };
  s_wg_denoms = { 32, 32, 1 };

- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
  }
  #undef CREATE_MM

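Editor's note: the consistent change to the CREATE_MM invocations above is the extra trailing argument (0, mul_mat_subgroup_size, or mul_mat_subgroup_size_16). The macro body itself is not part of this diff, so the snippet below is only an illustrative C++ sketch of the selection logic visible in the new mul_mat_id branch; DeviceCaps and pick_required_subgroup_size are hypothetical names, not package API.

#include <cstdint>

// Sketch only: mirrors the added branch where capable devices get the
// *_subgroup shader variants pinned to a required subgroup size, while
// everything else passes 0 ("no required subgroup size").
struct DeviceCaps {
    bool subgroup_ballot;
    bool subgroup_require_full_support;
    bool subgroup_min_size_16;
    uint32_t mul_mat_subgroup_size; // assumed value, e.g. 32
};

static uint32_t pick_required_subgroup_size(const DeviceCaps& d) {
    if (d.subgroup_ballot && d.subgroup_require_full_support && d.subgroup_min_size_16) {
        return d.mul_mat_subgroup_size; // pin the compute shader to this width
    }
    return 0; // leave the subgroup width up to the driver
}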
@@ -2705,52 +2911,61 @@ static void ggml_vk_load_shaders(vk_device& device) {
  rm_stdq = 2;
  uint32_t rm_iq = 2 * rm_kq;

- for (uint32_t
- [removed lines 2709-2753 are not rendered in this diff view]
+ for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
+ uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
+ uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
+
+ const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+
+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[s], arr_dmmv_f16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[s], arr_dmmv_bf16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[s], arr_dmmv_q4_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[s], arr_dmmv_q4_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[s], arr_dmmv_q5_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[s], arr_dmmv_q5_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[s], arr_dmmv_q8_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[s], arr_dmmv_q2_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[s], arr_dmmv_q3_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[s], arr_dmmv_q4_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[s], arr_dmmv_q5_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[s], arr_dmmv_q6_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[s], arr_dmmv_iq1_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[s], arr_dmmv_iq1_m_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[s], arr_dmmv_iq2_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[s], arr_dmmv_iq2_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[s], arr_dmmv_iq2_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[s], arr_dmmv_iq3_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[s], arr_dmmv_iq3_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[s], arr_dmmv_iq4_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[s], arr_dmmv_iq4_nl_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[s], arr_dmmv_mxfp4_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[s], arr_dmmv_f32_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[s], arr_dmmv_f16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[s], arr_dmmv_bf16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[s], arr_dmmv_q4_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[s], arr_dmmv_q4_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[s], arr_dmmv_q5_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[s], arr_dmmv_q5_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[s], arr_dmmv_q8_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[s], arr_dmmv_q2_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[s], arr_dmmv_q3_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[s], arr_dmmv_q4_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[s], arr_dmmv_q5_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[s], arr_dmmv_q6_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[s], arr_dmmv_iq1_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[s], arr_dmmv_iq1_m_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[s], arr_dmmv_iq2_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[s], arr_dmmv_iq2_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[s], arr_dmmv_iq2_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[s], arr_dmmv_iq3_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[s], arr_dmmv_iq3_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[s], arr_dmmv_iq4_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ }
  }

  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
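Editor's note: the new w loop above builds every mul_mat_vec pipeline at two workgroup widths, the plain subgroup width and four times it. The numbers below are only an illustration assuming device->subgroup_size == 32 and subgroup_size_16 == 16; actual values depend on the device.

// Illustrative values only:
//   w == DMMV_WG_SIZE_SUBGROUP : wg_size_subgroup = 32,  wg_size_subgroup16 = 16
//   any other w                : wg_size_subgroup = 128, wg_size_subgroup16 = 64
// K-quant and IQ/MXFP4 types dispatch with the 16-wide variant (wg_size_subgroup16),
// the float and classic quant types with the full-width one (wg_size_subgroup).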
@@ -2775,6 +2990,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);

  // dequant shaders
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2797,6 +3013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

  // get_rows
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2816,6 +3033,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2834,9 +3052,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main",
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2846,12 +3065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
  }
  }
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3,
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->
+
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2921,22 +3144,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
  };

  bool rte = device->float_controls_rte_fp16;
- #define CREATE_BINARY(name, namemod, spec) \
+ #define CREATE_BINARY(name, namemod, spec, bindings) \
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
  ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
  #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
- "main",
-
- CREATE_BINARY(add, , {0})
- CREATE_BINARY(add, _norepeat, {1})
- CREATE_BINARY(sub, , {0})
- CREATE_BINARY(sub, _norepeat, {1})
- CREATE_BINARY(mul, , {0})
- CREATE_BINARY(mul, _norepeat, {1})
- CREATE_BINARY(div, , {0})
- CREATE_BINARY(div, _norepeat, {1})
+ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+ CREATE_BINARY(add, , {0}, 4)
+ CREATE_BINARY(add, _norepeat, {1}, 4)
+ CREATE_BINARY(sub, , {0}, 3)
+ CREATE_BINARY(sub, _norepeat, {1}, 3)
+ CREATE_BINARY(mul, , {0}, 3)
+ CREATE_BINARY(mul, _norepeat, {1}, 3)
+ CREATE_BINARY(div, , {0}, 3)
+ CREATE_BINARY(div, _norepeat, {1}, 3)
+ CREATE_BINARY(add_rms, , {0}, 4)
+ CREATE_BINARY(add_rms, _norepeat, {1}, 4)
  #undef CREATE_BINARY

+ if (device->multi_add) {
+ for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+ }
+ }
+
+ ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
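Editor's note: the CREATE_BINARY macro body is fully visible in the hunk above, so a single iteration of the add instantiation can be written out by hand. The expansion below is only a reading of that macro text; rte, get_suffix and the add_len/add_data tables are defined elsewhere in the file and are assumed here.

// Roughly what one iteration (s0 = s1 = d = 0) of CREATE_BINARY(add, , {0}, 4) expands to:
ggml_vk_create_pipeline(device, device->pipeline_add[0][0][0],
    "add" + get_suffix(0, 0, 0) + "",
    add_len[0][0][0][rte], add_data[0][0][0][rte],
    "main", 4, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);

Note that add and the new add_rms request 4 descriptor bindings while sub/mul/div stay at 3, presumably to make room for the partials buffer used by the add+RMS-norm fusion pipelines added in this version.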
@@ -2950,6 +3184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

@@ -2966,6 +3201,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+ CREATE_UNARY(exp)
  CREATE_UNARY(gelu)
  CREATE_UNARY(gelu_erf)
  CREATE_UNARY(gelu_quick)
@@ -2987,6 +3223,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_GLU(geglu)
  CREATE_GLU(reglu)
  CREATE_GLU(swiglu)
+ CREATE_GLU(swiglu_oai)
  CREATE_GLU(geglu_erf)
  CREATE_GLU(geglu_quick)
  #undef CREATE_GLU
@@ -2996,10 +3233,10 @@ static void ggml_vk_load_shaders(vk_device& device) {

  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main",
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main",
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main",
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main",
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3019,11 +3256,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  }

- [removed line 3022 is not rendered in this diff view]
+ for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
+ }

  ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

- ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(
+ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

  ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

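Editor's note: the new argsort loop above creates one pipeline per power-of-two workgroup width, with the exponent also passed as a specialization constant; the value of num_argsort_pipelines is not visible in this hunk, so the example below is illustrative only.

// For i = 5 the loop above would create:
//   name           : "argsort_f32_5"
//   workgroup size : {1u << 5, 1, 1} = {32, 1, 1}
//   spec constants : {32, 5}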
@@ -3046,44 +3285,114 @@ static void ggml_vk_load_shaders(vk_device& device) {

  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+ ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
  // conv2d
- uint32_t
- [removed lines 3051-3068 are not rendered in this diff view]
- }
- [removed lines 3070-3082 are not rendered in this diff view]
+ for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
+ uint32_t conv2d_WG_SIZE = 256;
+ uint32_t conv2d_BS_K = 128;
+ uint32_t conv2d_BS_CRS = 16;
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
+ uint32_t conv2d_BS_NPQ = 128;
+ uint32_t conv2d_TS_K = 8;
+ uint32_t conv2d_SHMEM_PAD = 4;
+ bool conv2d_UNROLL = true;
+
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ if (device->coopmat2) {
+ conv2d_SHMEM_PAD = 8; // 8 float16_t
+ }
+ #endif
+
+ if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+ conv2d_SHMEM_PAD = 0;
+ conv2d_UNROLL = false;
+ } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
+ conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
+ }
+
+ switch (s) {
+ default:
+ case CONV_SHAPE_128x128:
+ conv2d_BS_K = 128;
+ conv2d_BS_NPQ = 128;
+ conv2d_BS_CRS = 16;
+ if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
+ conv2d_UNROLL = false;
+ }
+ break;
+ case CONV_SHAPE_64x32:
+ conv2d_BS_K = 64;
+ conv2d_BS_NPQ = 32;
+ conv2d_BS_CRS = 32;
+ conv2d_TS_K = 4;
+ break;
+ case CONV_SHAPE_32x256:
+ conv2d_BS_K = 32;
+ conv2d_BS_NPQ = 256;
+ conv2d_BS_CRS = 16;
+ break;
+ }
+
+ // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
+ bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
+ device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
+ bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
+ device->architecture == vk_device_architecture::AMD_GCN;
+
+ if (device->subgroup_shuffle &&
+ device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
+ allow_collectives_nv &&
+ allow_collectives_amd) {
+ use_collectives = 1;
+ conv2d_BS_CRS = std::min(
+ device->subgroup_size,
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+ }
+
+ uint32_t conv2d_shmem_req =
+ (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+ conv2d_BS_CRS = 8;
+ if (use_collectives) {
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+ }
+ }
+
+ std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
+ std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
+
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ if (device->coopmat2) {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ } else
+ #endif
+ if (conv2d_UNROLL) {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ } else {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ }
  }

  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

  for (auto &c : compiles) {
  c.wait();
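Editor's note: the shared-memory bound in the conv2d block above can be checked by hand. The figures below use the default CONV_SHAPE_128x128 parameters on a device that keeps the default padding (neither Intel nor AMD GCN); other shapes and vendors plug into the same formula.

// conv2d_shmem_req with conv2d_BS_K = 128, conv2d_BS_CRS = 16,
// conv2d_BS_NPQ = 128, conv2d_SHMEM_PAD = 4:
//   (128 * (16 + 4) + 16 * (128 + 4)) * sizeof(float)
// = (2560 + 2112) * 4
// = 18688 bytes
// Only if maxComputeSharedMemorySize is below this does conv2d_BS_CRS fall back to 8.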
@@ -3125,6 +3434,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
  device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;

+ const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
+ device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
+
  bool fp16_storage = false;
  bool fp16_compute = false;
  bool maintenance4_support = false;
@@ -3269,6 +3581,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);

+ device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
+
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -3402,6 +3717,12 @@ static vk_device ggml_vk_get_device(size_t idx) {

  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;

+ device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
+ device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
+ vk12_features.runtimeDescriptorArray &&
+ device->vendor_id != VK_VENDOR_ID_INTEL &&
+ getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
+
  if (device->subgroup_size_control) {
  device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
  device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -3412,9 +3733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3412
3733
|
(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
3413
3734
|
subgroup_size_control_features.subgroupSizeControl;
|
|
3414
3735
|
|
|
3415
|
-
|
|
3416
|
-
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
|
|
3417
|
-
}
|
|
3736
|
+
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
|
|
3418
3737
|
|
|
3419
3738
|
#if defined(VK_KHR_cooperative_matrix)
|
|
3420
3739
|
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
|
@@ -3715,6 +4034,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3715
4034
|
|
|
3716
4035
|
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
|
3717
4036
|
|
|
4037
|
+
device->add_rms_fusion = !device->disable_fusion &&
|
|
4038
|
+
device->subgroup_add &&
|
|
4039
|
+
device->vendor_id != VK_VENDOR_ID_INTEL;
|
|
4040
|
+
device->partials_binding_alignment =
|
|
4041
|
+
std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
|
|
4042
|
+
|
|
3718
4043
|
return device;
|
|
3719
4044
|
}
|
|
3720
4045
|
|
|
@@ -4139,6 +4464,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4209,6 +4535,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4224,7 +4551,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 }

-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
     VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
     GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
     GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
@@ -4252,12 +4579,30 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
     }

-
+    // heuristic to choose workgroup size
+    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
+    if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+        // Prefer larger workgroups when M is small, to spread the work out more
+        // and keep more SMs busy.
+        // q6_k seems to prefer small workgroup size even for "medium" values of M.
+        if (a_type == GGML_TYPE_Q6_K) {
+            if (m < 4096 && k >= 1024) {
+                dmmv_wg = DMMV_WG_SIZE_LARGE;
+            }
+        } else {
+            if (m <= 8192 && k >= 1024) {
+                dmmv_wg = DMMV_WG_SIZE_LARGE;
+            }
+        }
+    }
+
+    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
 }

 static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
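
The heuristic above picks between two dequant-mul-mat-vec workgroup sizes on NVIDIA and Intel: large workgroups when M is small enough and K is at least 1024, with a tighter M threshold for Q6_K. A self-contained sketch of the same decision, with the device query and pipeline arrays replaced by plain parameters (assumed names, not the upstream API):

```cpp
#include <cstdint>
#include <cstdio>

enum DmmvWg { DMMV_WG_SIZE_SUBGROUP = 0, DMMV_WG_SIZE_LARGE = 1 };

// Thresholds copied from the diff; the vendor check and type check are passed in as booleans.
static DmmvWg pick_dmmv_wg(bool nvidia_or_intel, bool is_q6_k, uint32_t m, uint32_t k) {
    DmmvWg wg = DMMV_WG_SIZE_SUBGROUP;
    if (nvidia_or_intel && k >= 1024) {
        // q6_k keeps the small workgroup except for small M.
        if (( is_q6_k && m <  4096) ||
            (!is_q6_k && m <= 8192)) {
            wg = DMMV_WG_SIZE_LARGE;
        }
    }
    return wg;
}

int main() {
    std::printf("%d\n", pick_dmmv_wg(true, false, 2048, 4096)); // prints 1: large workgroup
}
```
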
@@ -4306,12 +4651,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
     }

-
+    // XXX TODO 'prec' is not actually allowed in mul_mat_id.
+    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
+    bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
+    bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+
+    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
+        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+    } else {
+        GGML_ASSERT(support_fp32acc);
+        return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+    }
 }

 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -4341,6 +4697,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4526,6 +4883,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());

     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -4648,7 +5006,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
             }
         }

-        ggml_vk_sync_buffers(subctx);
+        ggml_vk_sync_buffers(ctx, subctx);
         subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
         return;
     }
@@ -4663,7 +5021,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
     VkBufferCopy buf_copy{ 0, offset, copy_size };

-    ggml_vk_sync_buffers(subctx);
+    ggml_vk_sync_buffers(ctx, subctx);
     vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);

     for (uint64_t i3 = 0; i3 < ne3; i3++) {
@@ -4717,7 +5075,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
             }
         }

-        ggml_vk_sync_buffers(subctx);
+        ggml_vk_sync_buffers(nullptr, subctx);
         subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
         return;
     }
@@ -4738,7 +5096,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
         offset,
         copy_size};

-    ggml_vk_sync_buffers(subctx);
+    ggml_vk_sync_buffers(nullptr, subctx);
     vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);

     if (width == spitch) {
@@ -4818,7 +5176,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size

     if (buf != nullptr) {
         // Memory is pinned, use as staging buffer
-        ggml_vk_sync_buffers(subctx);
+        ggml_vk_sync_buffers(nullptr, subctx);
         subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);

         return;
@@ -4835,7 +5193,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size

     vk_buffer& staging_buffer = src->device->sync_staging;

-    ggml_vk_sync_buffers(subctx);
+    ggml_vk_sync_buffers(nullptr, subctx);
     subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);

     deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
@@ -4933,26 +5291,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_queue_command_pools_cleanup(dst->device);
 }

-static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx,
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");

     uint32_t split_k = 1;
-    if (ctx->device->shader_core_count != 0 && m >=
+    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
         // If k is 'large' and the SMs will fill less than halfway, use split_k.
         uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
         uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
-
-
-
-
-        if (
-            split_k =
+
+        if (k >= 2048) {
+            if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
+                split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
+                split_k = 3;
             }
-
-
-
-
+            // Cap the split at 8x. Unless k is huge this is a lot of overhead.
+            split_k = std::min(split_k, 8u);
+
+            // ggml_vk_matmul will align the splits to be a multiple of 256.
+            // If this rounded up size would cause the last split to be empty,
+            // then reduce the split count.
+            while (true) {
+                if (split_k == 1) {
+                    break;
                 }
+                uint32_t k_split = CEIL_DIV(k, split_k);
+                k_split = ROUNDUP_POW2(k_split, 256);
+                if (k_split * (split_k - 1) < k) {
+                    break;
+                }
+                split_k--;
             }
         }
     }
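
The rewritten ggml_vk_guess_split_k sizes the split by how far the tile grid under-fills the shader cores, caps it at 8, and then backs off while a 256-aligned split would leave the last chunk empty. A standalone sketch of that arithmetic with simplified helpers (CEIL_DIV/ROUNDUP_POW2 re-implemented locally, thresholds copied from the diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static uint32_t roundup(uint32_t x, uint32_t mult) { return ceil_div(x, mult) * mult; }

static uint32_t guess_split_k(uint32_t m_tiles, uint32_t n_tiles, uint32_t k, uint32_t sm_count) {
    uint32_t split_k = 1;
    if (k >= 2048) {
        if (m_tiles * n_tiles <= sm_count / 2) {
            split_k = sm_count / (m_tiles * n_tiles);
        } else if (m_tiles * n_tiles <= sm_count * 2 / 3) {
            split_k = 3;
        }
        split_k = std::min(split_k, 8u);
        while (split_k > 1) {
            uint32_t k_split = roundup(ceil_div(k, split_k), 256);
            if (k_split * (split_k - 1) < k) {
                break;          // last split still has work to do
            }
            split_k--;          // otherwise drop a split and retry
        }
    }
    return split_k;
}

int main() {
    // 2x2 tiles on a 32-SM GPU with k=8192: 32/4 = 8 splits of 1024 each
    std::printf("split_k = %u\n", guess_split_k(2, 2, 8192, 32));
}
```
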
@@ -4964,9 +5333,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");

     if (ctx->device->coopmat2) {
+        const uint32_t shader_core_count = ctx->device->shader_core_count;
+        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
+        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
+
         // Use large shader when the N dimension is greater than the medium shader's tile size
         uint32_t crossover_large = mmp->m->wg_denoms[1];
-
+
+        // Prefer large over medium if either:
+        // - medium or large tiles would overfill the GPU
+        // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
+        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
+        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
+                            // split_k==3 with large tiles likely better than medium tiles with no split_k.
+                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
+
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
         // Use medium shader when the N dimension is greater than the small shader's tile size
@@ -5001,21 +5383,29 @@ static void ggml_vk_matmul(
         uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
         uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
-    ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
         return;
     }

+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
     GGML_ASSERT(batch_stride_d == m * n);

-
+    // Round the split size up to a multiple of 256 (k-quant alignment)
+    uint32_t k_split = CEIL_DIV(k, split_k);
+    k_split = ROUNDUP_POW2(k_split, 256);
+
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
-    ggml_vk_sync_buffers(subctx);
+    ggml_vk_sync_buffers(ctx, subctx);
     const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
+    ctx->prealloc_split_k_need_sync = true;
 }

 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
@@ -5060,7 +5450,6 @@ static void ggml_vk_matmul_id(
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
-    ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
                                               nei0, nei1, nbi1, ne11, padded_n };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
@@ -5191,8 +5580,8 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     };
     init_pushconst_fastdiv(pc);
-    ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
+    ggml_vk_sync_buffers(ctx, subctx);
 }

 static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -5210,14 +5599,14 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub

     vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);

-    ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
+    ggml_vk_sync_buffers(ctx, subctx);
 }

 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
-    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
-    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
     GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
     GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
@@ -5406,18 +5795,39 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         GGML_ASSERT(qy_sz == y_sz);
     }

+    if (x_non_contig || qx_needs_dequant) {
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    if (y_non_contig || quantize_y) {
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+
     if (x_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (qx_needs_dequant) {
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
-        ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+        ggml_vk_sync_buffers(ctx, subctx);
     }
     if (y_non_contig) {
-
+        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
     }
     if (quantize_y) {
-
+        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
     }

     uint32_t stride_batch_x = ne00*ne01;
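
The prealloc_y_last_pipeline_used / prealloc_y_last_tensor_used pair added here lets consecutive matmuls reuse the already-converted copy of src1 instead of re-dispatching the to-fp16 or quantize pass. A simplified sketch of that cache check (stand-in types, not the real vk_pipeline / ggml_tensor):

```cpp
#include <cstdio>

struct Pipeline { const char * name; };
struct Tensor   { const char * name; };

struct ConvertCache {
    const Pipeline * last_pipeline = nullptr;
    const Tensor   * last_tensor   = nullptr;

    // Returns true when a new conversion dispatch is required.
    bool needs_convert(const Pipeline * p, const Tensor * t) {
        if (last_pipeline == p && last_tensor == t) {
            return false;           // same tensor already converted with the same pipeline
        }
        last_pipeline = p;
        last_tensor   = t;
        return true;
    }
};

int main() {
    Pipeline to_fp16{"to_fp16"};
    Tensor src1{"src1"};
    ConvertCache cache;
    std::printf("%d %d\n", cache.needs_convert(&to_fp16, &src1),   // 1: convert
                           cache.needs_convert(&to_fp16, &src1));  // 0: reuse
}
```
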
@@ -5440,6 +5850,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
         split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     ); // NOLINT
+
+    if (x_non_contig || qx_needs_dequant) {
+        ctx->prealloc_x_need_sync = true;
+    }
+    if (y_non_contig || quantize_y) {
+        ctx->prealloc_y_need_sync = true;
+    }
 }

 static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5523,7 +5940,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
+    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
     GGML_ASSERT(dmmv != nullptr);
@@ -5586,13 +6003,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         GGML_ASSERT(qy_sz == y_sz);
     }

+    if (x_non_contig) {
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    if (y_non_contig) {
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     }
     if (y_non_contig) {
         GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
-
+        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
     }

     // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
@@ -5624,10 +6057,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         stride_batch_x, stride_batch_y, stride_batch_d,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
-    ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
         pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+
+    if (x_non_contig) {
+        ctx->prealloc_x_need_sync = true;
+    }
+    if (y_non_contig) {
+        ctx->prealloc_y_need_sync = true;
+    }
 }

 static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5714,7 +6153,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         workgroups_z /= gqa_ratio;
     }

-    ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
 }

@@ -5732,7 +6170,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t ne00 = src0->ne[0];
     const uint64_t ne01 = src0->ne[1];
     const uint64_t ne02 = src0->ne[2];
-
+    const uint64_t ne03 = src0->ne[3];

     const uint64_t nb01 = src0->nb[1];
     const uint64_t nb02 = src0->nb[2];
@@ -5744,7 +6182,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t ne12 = src1->ne[2];
     // const uint64_t ne13 = src1->ne[3];

+    const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
+    const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
+    const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
+
     GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op

     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
@@ -5760,7 +6203,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         src1_uma = d_Qy != nullptr;
     }

-    const uint64_t d_ne = ne01 * ne11 * ne12;
+    const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;

     const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
@@ -5795,10 +6238,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;

     // compute
-    const std::array<uint32_t,
-    ggml_vk_sync_buffers(subctx);
+    const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
-        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, {
+        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
 }

 static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5847,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&

     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-    GGML_ASSERT(nei0 * nei1 <= 4096);

     const uint32_t nbi1 = ids->nb[1];
     const uint32_t nbi2 = ids->nb[2];
@@ -6008,16 +6449,32 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         GGML_ASSERT(qy_sz == y_sz);
     }

+    if (x_non_contig || qx_needs_dequant) {
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    if (y_non_contig) {
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+
     if (x_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (qx_needs_dequant) {
         const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
-        ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
             { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+        ggml_vk_sync_buffers(ctx, subctx);
     }
     if (y_non_contig) {
-
+        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
     }

     uint32_t stride_batch_x = ne00*ne01;
@@ -6040,6 +6497,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         stride_batch_x, stride_batch_y, ne20*ne21,
         n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
     ); // NOLINT
+
+    if (x_non_contig || qx_needs_dequant) {
+        ctx->prealloc_x_need_sync = true;
+    }
+    if (y_non_contig) {
+        ctx->prealloc_y_need_sync = true;
+    }
 }

 static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
@@ -6199,13 +6663,29 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         GGML_ASSERT(qy_sz == y_sz);
     }

+    if (x_non_contig) {
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    if (y_non_contig) {
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+
     if (x_non_contig) {
         GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     }
    if (y_non_contig) {
         GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
-
+        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
     }

     uint32_t stride_batch_y = ne10*ne11;
@@ -6230,11 +6710,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
         (uint32_t)nei0, (uint32_t)ne11,
     };
-    ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
         { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
         vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
         pc, { groups_x, (uint32_t)nei0, groups_z });
+
+    if (x_non_contig) {
+        ctx->prealloc_x_need_sync = true;
+    }
+    if (y_non_contig) {
+        ctx->prealloc_y_need_sync = true;
+    }
 }

 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6242,30 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     } else {
-
-        const uint32_t nei0 = (uint32_t)src2->ne[0];
-        const uint32_t nei1 = (uint32_t)src2->ne[1];
-
-        GGML_ASSERT(nei0 <= 4096);
-        const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
-        ggml_tensor src1_copy = *src1;
-        ggml_tensor src2_copy = *src2;
-        ggml_tensor dst_copy = *dst;
-
-        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
-            const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
-            src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
-            src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
-            dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
-            src1_copy.ne[2] = n_tokens;
-            src2_copy.ne[1] = n_tokens;
-            dst_copy.ne[2] = n_tokens;
-
-            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
-        }
+        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
     }
 }

@@ -6298,18 +6761,21 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t Br = coopmat1_flash_attention_num_large_rows;
     const uint32_t Bc = scalar_flash_attention_Bc;

+    const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;

     const uint32_t tmpsh = wg_size * sizeof(float);
     const uint32_t tmpshv4 = wg_size * 4 * acctype;

-    const uint32_t
+    const uint32_t qstride = hsk_pad / 4 + 2;
+    const uint32_t Qf = Br * qstride * f16vec4;

     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;

-    const uint32_t kshstride =
+    const uint32_t kshstride = hsk_pad / 4 + 2;
     const uint32_t ksh = Bc * kshstride * f16vec4;

     const uint32_t slope = Br * sizeof(float);
@@ -6322,11 +6788,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     return supported;
 }

-static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
     std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    if (sinks) {
+        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+    }
     std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");

     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
@@ -6417,7 +6886,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         workgroups_y /= N;
     }

-    vk_pipeline *pipelines;
     bool small_rows = N <= get_fa_num_small_rows(path);

     // coopmat1 does not actually support "small rows" (it needs 16 rows).
@@ -6437,37 +6905,36 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         small_rows = true;
     }

-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
-    FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
-
-    switch (path) {
-    case FA_SCALAR:
-        pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
-        break;
-    case FA_COOPMAT1:
-        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
-        break;
-    case FA_COOPMAT2:
-        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
-        break;
-    default:
-        GGML_ASSERT(0);
-    }
-    assert(pipelines);
-
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
     const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
     const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));

-
+    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+    bool aligned = (KV % alignment) == 0 &&
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;

+    // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
+    if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+        aligned = false;
+    }
     // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
     GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);

-
+    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+
+    vk_pipeline pipeline = nullptr;
+
+    auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
+    auto it = pipelines.find(fa_pipeline_state);
+    if (it != pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
+    }
+
     assert(pipeline);

     uint32_t split_kv = KV;
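
The flash-attention path now replaces the fixed per-variant pipeline arrays with a map keyed by the full pipeline state, creating an empty shared_ptr entry on first use so the pipeline can be compiled lazily. A simplified sketch of that lookup pattern (FaKey here is a stand-in for vk_fa_pipeline_state, not the upstream type):

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <tuple>

struct FaKey {
    uint32_t hsk, hsv;
    bool     small_rows, aligned, f32acc;
    int      path;
    bool operator<(const FaKey & o) const {
        return std::tie(hsk, hsv, small_rows, aligned, f32acc, path) <
               std::tie(o.hsk, o.hsv, o.small_rows, o.aligned, o.f32acc, o.path);
    }
};

struct PipelineStruct { /* compiled shader state would live here */ };
using Pipeline = std::shared_ptr<PipelineStruct>;

// Return the cached pipeline object for this variant, creating an empty one on first use.
Pipeline get_or_create(std::map<FaKey, Pipeline> & pipelines, const FaKey & key) {
    auto it = pipelines.find(key);
    if (it != pipelines.end()) {
        return it->second;
    }
    return pipelines[key] = std::make_shared<PipelineStruct>();
}

int main() {
    std::map<FaKey, Pipeline> cache;
    FaKey key{64, 64, false, true, false, 0};
    auto a = get_or_create(cache, key);
    auto b = get_or_create(cache, key);
    return a == b ? 0 : 1; // same object is returned on the second lookup
}
```
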
@@ -6483,7 +6950,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     if (split_k > 1) {
         // Try to evenly split KV into split_k chunks, but it needs to be a multiple
         // of "align", so recompute split_k based on that.
-        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k),
+        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
         split_k = CEIL_DIV(KV, split_kv);
         workgroups_x = split_k;
     }
@@ -6525,10 +6992,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
-    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+    vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+    size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;

-    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;

     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6543,6 +7010,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
+        if (sinks) {
+            ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+            S_uma = d_S != nullptr;
+        }
     }


@@ -6578,7 +7049,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-
+    if (!S_uma) {
+        d_S = d_Q;
+        s_buf_offset = q_buf_offset;
+        if (sinks) {
+            ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+            d_S = s_buf_ctx->dev_buffer;
+            s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+        }
+    }
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;

     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -6593,15 +7074,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               mask_n_head_log2, m0, m1,
                                               gqa_ratio, split_kv, split_k };

-    ggml_vk_sync_buffers(subctx);
-
     if (split_k > 1) {
+        if (ctx->prealloc_split_k_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
                 vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
             },
             // We only use split_k when group query attention is enabled, which means
@@ -6610,14 +7094,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             // cancel out the divide by wg_denoms[0].
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

-        ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t,
+        ggml_vk_sync_buffers(ctx, subctx);
+        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+        ctx->prealloc_split_k_need_sync = true;
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
@@ -6625,13 +7111,42 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                 vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
             pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }

-static
+static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+
+    // src0 - kernel:   [KW, KH, Cin, Cout]
+    // src1 - input:    [W, H, Cin, N]
+    // dst - result:    [OW, OH, Cout, N]
+
+    // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+    };
+    // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+    int64_t W    = src1->ne[0];
+    int64_t H    = src1->ne[1];
+    int64_t KW   = src0->ne[0];
+    int64_t KH   = src0->ne[1];
+    int64_t Cout = src0->ne[3];
+    int64_t N    = src1->ne[3];
+    int64_t OH   = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
+    int64_t OW   = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
+    int64_t NPQ  = N * OW * OH;
+
+    // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+    std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+    return elements;
+}
+
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
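
ggml_vk_get_conv_elements derives the dispatch grid from the standard convolution output-size formula, OW = (W + 2*p - d*(KW - 1) - 1) / s + 1, with the grid sized {Cout, N*OW*OH, 1}. A small worked example of that formula (values chosen for illustration only):

```cpp
#include <cstdint>
#include <cstdio>

// Same output-size formula as in the diff, with stride s, padding p, dilation d.
static int64_t conv_out(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main() {
    // 224x224 input, 3x3 kernel, stride 1, padding 1, dilation 1 -> 224x224 output
    int64_t OW = conv_out(224, 3, 1, 1, 1);
    int64_t OH = conv_out(224, 3, 1, 1, 1);
    // with Cout=64 and N=1 the workgroup grid covers {64, 224*224, 1} output tiles
    std::printf("OW=%lld OH=%lld NPQ=%lld\n",
                (long long)OW, (long long)OH, (long long)(1 * OW * OH));
}
```
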
@@ -6659,8 +7174,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
-
-
+            if (ctx->num_additional_fused_ops > 0) {
+                if (ctx->do_add_rms_partials) {
+                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+                } else {
+                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                }
+            }
+            if (ctx->do_add_rms_partials) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            }
         }
     case GGML_OP_SUB:
         {
@@ -6681,6 +7208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             break;
         }
         return nullptr;
+    case GGML_OP_ADD_ID:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_id_f32;
+        }
+        return nullptr;
     case GGML_OP_CONCAT:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_concat_f32;
@@ -6715,6 +7247,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6715
7247
|
return ctx->device->pipeline_sqr_f32;
|
|
6716
7248
|
}
|
|
6717
7249
|
return nullptr;
|
|
7250
|
+
case GGML_OP_SQRT:
|
|
7251
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
7252
|
+
return ctx->device->pipeline_sqrt_f32;
|
|
7253
|
+
}
|
|
7254
|
+
return nullptr;
|
|
6718
7255
|
case GGML_OP_SIN:
|
|
6719
7256
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6720
7257
|
return ctx->device->pipeline_sin_f32;
|
|
@@ -6773,7 +7310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6773
7310
|
return nullptr;
|
|
6774
7311
|
case GGML_OP_RMS_NORM:
|
|
6775
7312
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6776
|
-
|
|
7313
|
+
if (ctx->do_add_rms_partials) {
|
|
7314
|
+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
|
|
7315
|
+
} else {
|
|
7316
|
+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
|
|
7317
|
+
}
|
|
6777
7318
|
}
|
|
6778
7319
|
return nullptr;
|
|
6779
7320
|
case GGML_OP_RMS_NORM_BACK:
|
|
@@ -6794,6 +7335,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6794
7335
|
}
|
|
6795
7336
|
|
|
6796
7337
|
switch (ggml_get_unary_op(dst)) {
|
|
7338
|
+
case GGML_UNARY_OP_EXP:
|
|
7339
|
+
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
|
|
6797
7340
|
case GGML_UNARY_OP_SILU:
|
|
6798
7341
|
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
6799
7342
|
case GGML_UNARY_OP_GELU:
|
|
@@ -6826,6 +7369,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_SWIGLU:
                 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU_OAI:
+                return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_GEGLU_ERF:
                 return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_GEGLU_QUICK:
@@ -6841,6 +7386,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -6895,11 +7441,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
     case GGML_OP_ARGSORT:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
-
+            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+            return ctx->device->pipeline_argsort_f32[idx];
         }
         return nullptr;
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_sum_rows_f32;
         }
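The ARGSORT path above now indexes an array of pipelines with idx = ceil(log2(ncols)), presumably one variant per power-of-two padded row length, instead of passing a computed ncols_pad to a single shader. A tiny check of that index computation:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    for (int ncols : {1, 2, 3, 1000, 1024, 1025}) {
        uint32_t idx = (uint32_t)ceilf(log2f((float)ncols));
        // 1u << idx is the smallest power of two that covers ncols elements.
        printf("ncols=%d -> idx=%u (padded length %u)\n", ncols, idx, 1u << idx);
    }
    return 0;
}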
@@ -6952,15 +7500,44 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6952
7500
|
return ctx->device->pipeline_opt_step_adamw_f32;
|
|
6953
7501
|
}
|
|
6954
7502
|
return nullptr;
|
|
7503
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
7504
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
7505
|
+
return ctx->device->pipeline_opt_step_sgd_f32;
|
|
7506
|
+
}
|
|
7507
|
+
return nullptr;
|
|
6955
7508
|
case GGML_OP_LEAKY_RELU:
|
|
6956
7509
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6957
7510
|
return ctx->device->pipeline_leaky_relu_f32;
|
|
6958
7511
|
}
|
|
6959
7512
|
return nullptr;
|
|
6960
7513
|
case GGML_OP_CONV_2D:
|
|
6961
|
-
if (
|
|
7514
|
+
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
|
6962
7515
|
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
|
6963
|
-
|
|
7516
|
+
auto elements = ggml_vk_get_conv_elements(dst);
|
|
7517
|
+
vk_conv_shapes shape;
|
|
7518
|
+
|
|
7519
|
+
uint32_t tiles[CONV_SHAPE_COUNT];
|
|
7520
|
+
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
|
|
7521
|
+
tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
|
|
7522
|
+
}
|
|
7523
|
+
|
|
7524
|
+
// We can't query number of shader cores on Intel, use 32 as a placeholder
|
|
7525
|
+
// so small convolutions will still choose a smaller tile.
|
|
7526
|
+
const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
|
|
7527
|
+
|
|
7528
|
+
if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
|
|
7529
|
+
shape = CONV_SHAPE_128x128;
|
|
7530
|
+
} else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
|
|
7531
|
+
shape = CONV_SHAPE_32x256;
|
|
7532
|
+
} else {
|
|
7533
|
+
shape = CONV_SHAPE_64x32;
|
|
7534
|
+
}
|
|
7535
|
+
|
|
7536
|
+
if (src0->type == GGML_TYPE_F32) {
|
|
7537
|
+
return ctx->device->pipeline_conv2d_f32[shape];
|
|
7538
|
+
} else if (src0->type == GGML_TYPE_F16) {
|
|
7539
|
+
return ctx->device->pipeline_conv2d_f16_f32[shape];
|
|
7540
|
+
}
|
|
6964
7541
|
}
|
|
6965
7542
|
return nullptr;
|
|
6966
7543
|
case GGML_OP_CONV_2D_DW:
|
|
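The CONV_2D hunk above chooses among three tile shapes by counting how many workgroup tiles each shape would launch and preferring larger tiles only when they still keep the GPU busy (at least two tiles per shader core, with 32 cores assumed when the count cannot be queried). A sketch of that heuristic with placeholder wg_denoms; the real values come from the compiled conv2d pipelines:

#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // Assumed meaning: elements[0]/elements[1] are the two tiled dimensions of
    // the convolution expressed as a matrix multiply.
    const uint32_t elements[2] = { 256, 4096 };
    struct { const char * name; uint32_t wg0, wg1; } shapes[3] = {
        { "128x128", 128, 128 },
        { "64x32",    64,  32 },
        { "32x256",   32, 256 },
    };
    const uint32_t shader_core_count = 32;   // placeholder when the driver can't be queried

    for (auto & s : shapes) {
        uint32_t tiles = ceil_div(elements[0], s.wg0) * ceil_div(elements[1], s.wg1);
        bool saturates = tiles >= shader_core_count * 2;
        printf("%s: %u tiles, saturates=%d\n", s.name, tiles, saturates);
    }
    return 0;
}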
@@ -6970,6 +7547,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6970
7547
|
} else if (ggml_is_contiguous_channels(src1)) {
|
|
6971
7548
|
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
|
|
6972
7549
|
}
|
|
7550
|
+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
7551
|
+
if (ggml_is_contiguous(src1)) {
|
|
7552
|
+
return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
|
|
7553
|
+
} else if (ggml_is_contiguous_channels(src1)) {
|
|
7554
|
+
return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
|
|
7555
|
+
}
|
|
6973
7556
|
}
|
|
6974
7557
|
return nullptr;
|
|
6975
7558
|
default:
|
|
@@ -6987,9 +7570,11 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
6987
7570
|
case GGML_OP_SUB:
|
|
6988
7571
|
case GGML_OP_MUL:
|
|
6989
7572
|
case GGML_OP_DIV:
|
|
7573
|
+
case GGML_OP_ADD_ID:
|
|
6990
7574
|
case GGML_OP_CONCAT:
|
|
6991
7575
|
case GGML_OP_UPSCALE:
|
|
6992
7576
|
case GGML_OP_SQR:
|
|
7577
|
+
case GGML_OP_SQRT:
|
|
6993
7578
|
case GGML_OP_SIN:
|
|
6994
7579
|
case GGML_OP_COS:
|
|
6995
7580
|
case GGML_OP_CLAMP:
|
|
@@ -7001,6 +7586,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
7001
7586
|
case GGML_OP_CONV_2D_DW:
|
|
7002
7587
|
case GGML_OP_IM2COL:
|
|
7003
7588
|
case GGML_OP_SET_ROWS:
|
|
7589
|
+
case GGML_OP_SUM:
|
|
7590
|
+
case GGML_OP_SUM_ROWS:
|
|
7591
|
+
case GGML_OP_MEAN:
|
|
7004
7592
|
return true;
|
|
7005
7593
|
default:
|
|
7006
7594
|
return false;
|
|
@@ -7035,6 +7623,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
|
|
|
7035
7623
|
GGML_UNUSED(src2);
|
|
7036
7624
|
}
|
|
7037
7625
|
|
|
7626
|
+
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
|
7627
|
+
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
|
7628
|
+
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
|
7629
|
+
|
|
7630
|
+
p.misalign_offsets = (a_offset << 16) | d_offset;
|
|
7631
|
+
|
|
7632
|
+
GGML_UNUSED(src1);
|
|
7633
|
+
GGML_UNUSED(src2);
|
|
7634
|
+
}
|
|
7635
|
+
|
|
7038
7636
|
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
|
7039
7637
|
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
|
7040
7638
|
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
|
|
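The new sum_rows push constants pack two element offsets into one 32-bit word, (a_offset << 16) | d_offset, so the shader can compensate for sub-alignment of the bound buffers. A quick round-trip of that packing, with made-up values:

#include <cstdint>
#include <cstdio>

int main() {
    uint32_t a_offset = 37;   // elements of misalignment in the source
    uint32_t d_offset = 5;    // elements of misalignment in the destination
    uint32_t packed = (a_offset << 16) | d_offset;

    uint32_t a_back = packed >> 16;      // high 16 bits: source offset
    uint32_t d_back = packed & 0xffff;   // low 16 bits: destination offset
    printf("packed=0x%08x a=%u d=%u\n", packed, a_back, d_back);
    return 0;
}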
@@ -7185,10 +7783,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7185
7783
|
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
|
7186
7784
|
|
|
7187
7785
|
if (op_supports_incontiguous) {
|
|
7188
|
-
x_sz = ggml_nbytes(src0);
|
|
7189
|
-
y_sz = use_src1 ? ggml_nbytes(src1) : 0;
|
|
7190
|
-
z_sz = use_src2 ? ggml_nbytes(src2) : 0;
|
|
7191
|
-
d_sz = ggml_nbytes(dst);
|
|
7786
|
+
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
|
|
7787
|
+
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
|
|
7788
|
+
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
|
|
7789
|
+
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
|
|
7192
7790
|
|
|
7193
7791
|
if (x_buf_offset + x_sz >= d_X->size) {
|
|
7194
7792
|
x_sz = VK_WHOLE_SIZE;
|
|
@@ -7216,6 +7814,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7216
7814
|
case GGML_OP_SOFT_MAX:
|
|
7217
7815
|
case GGML_OP_SOFT_MAX_BACK:
|
|
7218
7816
|
case GGML_OP_SUM_ROWS:
|
|
7817
|
+
case GGML_OP_MEAN:
|
|
7219
7818
|
case GGML_OP_ARGMAX:
|
|
7220
7819
|
{
|
|
7221
7820
|
const uint32_t nr = ggml_nrows(src0);
|
|
@@ -7228,7 +7827,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7228
7827
|
}
|
|
7229
7828
|
} break;
|
|
7230
7829
|
case GGML_OP_RMS_NORM:
|
|
7231
|
-
|
|
7830
|
+
if (ctx->do_add_rms_partials) {
|
|
7831
|
+
// Run one element per thread, 128 threads per workgroup
|
|
7832
|
+
elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
|
|
7833
|
+
} else {
|
|
7834
|
+
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
|
|
7835
|
+
}
|
|
7232
7836
|
break;
|
|
7233
7837
|
|
|
7234
7838
|
case GGML_OP_SUM:
|
|
@@ -7287,35 +7891,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7287
7891
|
} break;
|
|
7288
7892
|
case GGML_OP_CONV_2D:
|
|
7289
7893
|
{
|
|
7290
|
-
|
|
7291
|
-
|
|
7292
|
-
// dst - result: [OW, OH, Cout, N]
|
|
7293
|
-
|
|
7294
|
-
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
|
|
7295
|
-
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
|
7296
|
-
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
|
7297
|
-
};
|
|
7298
|
-
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
|
|
7299
|
-
int64_t W = src1->ne[0];
|
|
7300
|
-
int64_t H = src1->ne[1];
|
|
7301
|
-
int64_t KW = src0->ne[0];
|
|
7302
|
-
int64_t KH = src0->ne[1];
|
|
7303
|
-
int64_t Cout = src0->ne[3];
|
|
7304
|
-
int64_t N = src1->ne[3];
|
|
7305
|
-
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
|
|
7306
|
-
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
|
|
7307
|
-
int64_t NPQ = N * OW * OH;
|
|
7308
|
-
|
|
7309
|
-
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
|
|
7310
|
-
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
|
|
7311
|
-
}
|
|
7312
|
-
break;
|
|
7894
|
+
elements = ggml_vk_get_conv_elements(dst);
|
|
7895
|
+
} break;
|
|
7313
7896
|
case GGML_OP_ADD:
|
|
7314
7897
|
case GGML_OP_SUB:
|
|
7315
7898
|
case GGML_OP_DIV:
|
|
7316
7899
|
case GGML_OP_MUL:
|
|
7317
7900
|
case GGML_OP_SCALE:
|
|
7318
7901
|
case GGML_OP_SQR:
|
|
7902
|
+
case GGML_OP_SQRT:
|
|
7319
7903
|
case GGML_OP_SIN:
|
|
7320
7904
|
case GGML_OP_COS:
|
|
7321
7905
|
case GGML_OP_CLAMP:
|
|
@@ -7354,6 +7938,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7354
7938
|
elements = { ne, 1, 1 };
|
|
7355
7939
|
}
|
|
7356
7940
|
} break;
|
|
7941
|
+
case GGML_OP_ADD_ID:
|
|
7942
|
+
{
|
|
7943
|
+
elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
|
|
7944
|
+
} break;
|
|
7357
7945
|
case GGML_OP_SET_ROWS:
|
|
7358
7946
|
{
|
|
7359
7947
|
uint32_t ne = ggml_nelements(src0);
|
|
@@ -7393,8 +7981,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7393
7981
|
}
|
|
7394
7982
|
}
|
|
7395
7983
|
|
|
7396
|
-
if (op ==
|
|
7397
|
-
|
|
7984
|
+
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
|
|
7985
|
+
vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
|
|
7986
|
+
size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
|
|
7987
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
7988
|
+
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
|
|
7989
|
+
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
|
|
7990
|
+
vk_subbuffer{ d_D, d_buf_offset, d_sz },
|
|
7991
|
+
vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
|
|
7992
|
+
}, pc, elements);
|
|
7993
|
+
} else if (op == GGML_OP_GLU) {
|
|
7994
|
+
// Empty src1 is possible in glu, but the shader needs a buffer
|
|
7398
7995
|
vk_subbuffer subbuf_y;
|
|
7399
7996
|
if (use_src1) {
|
|
7400
7997
|
subbuf_y = { d_Y, y_buf_offset, y_sz };
|
|
@@ -7402,8 +7999,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7402
7999
|
subbuf_y = { d_X, 0, x_sz };
|
|
7403
8000
|
}
|
|
7404
8001
|
|
|
7405
|
-
ggml_vk_sync_buffers(subctx);
|
|
7406
8002
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
8003
|
+
} else if (op == GGML_OP_SOFT_MAX) {
|
|
8004
|
+
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
|
|
8005
|
+
vk_subbuffer subbuf_y;
|
|
8006
|
+
if (use_src1) {
|
|
8007
|
+
subbuf_y = { d_Y, y_buf_offset, y_sz };
|
|
8008
|
+
} else {
|
|
8009
|
+
subbuf_y = { d_X, 0, x_sz };
|
|
8010
|
+
}
|
|
8011
|
+
|
|
8012
|
+
vk_subbuffer subbuf_z;
|
|
8013
|
+
if (use_src2) {
|
|
8014
|
+
subbuf_z = { d_Z, z_buf_offset, z_sz };
|
|
8015
|
+
} else {
|
|
8016
|
+
subbuf_z = { d_X, 0, x_sz };
|
|
8017
|
+
}
|
|
8018
|
+
|
|
8019
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7407
8020
|
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
|
|
7408
8021
|
// Empty src2 is possible in rope, but the shader needs a buffer
|
|
7409
8022
|
vk_subbuffer subbuf_z;
|
|
@@ -7413,26 +8026,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7413
8026
|
subbuf_z = { d_X, 0, x_sz };
|
|
7414
8027
|
}
|
|
7415
8028
|
|
|
7416
|
-
ggml_vk_sync_buffers(subctx);
|
|
7417
8029
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7418
8030
|
} else if (op == GGML_OP_IM2COL) {
|
|
7419
8031
|
// im2col uses only src1 and dst buffers
|
|
7420
|
-
ggml_vk_sync_buffers(subctx);
|
|
7421
8032
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7422
8033
|
} else if (op == GGML_OP_COUNT_EQUAL) {
|
|
7423
|
-
ggml_vk_sync_buffers(subctx);
|
|
7424
8034
|
// count_equal assumes that destination buffer is initialized with zeroes
|
|
7425
8035
|
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
|
|
7426
|
-
ggml_vk_sync_buffers(subctx);
|
|
8036
|
+
ggml_vk_sync_buffers(ctx, subctx);
|
|
7427
8037
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
8038
|
+
} else if (op == GGML_OP_OPT_STEP_SGD) {
|
|
8039
|
+
// OPT_STEP_SGD works on src0, it does not need dst
|
|
8040
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
|
|
7428
8041
|
} else if (use_src2) {
|
|
7429
|
-
ggml_vk_sync_buffers(subctx);
|
|
7430
8042
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7431
8043
|
} else if (use_src1) {
|
|
7432
|
-
ggml_vk_sync_buffers(subctx);
|
|
7433
8044
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7434
8045
|
} else {
|
|
7435
|
-
ggml_vk_sync_buffers(subctx);
|
|
7436
8046
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
|
7437
8047
|
}
|
|
7438
8048
|
}
|
|
@@ -7472,6 +8082,116 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7472
8082
|
}, dryrun);
|
|
7473
8083
|
}
|
|
7474
8084
|
|
|
8085
|
+
static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
|
8086
|
+
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
|
8087
|
+
const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
|
|
8088
|
+
|
|
8089
|
+
// Make a list of all the tensors used by the op.
|
|
8090
|
+
// Last element of the list is the dest tensor.
|
|
8091
|
+
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
|
|
8092
|
+
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
|
|
8093
|
+
uint32_t num_tensors = num_srcs + 1;
|
|
8094
|
+
GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
|
|
8095
|
+
|
|
8096
|
+
tensors[0] = first_node->src[0];
|
|
8097
|
+
tensors[1] = first_node->src[1];
|
|
8098
|
+
for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
|
|
8099
|
+
// check whether the previous result is src[0] or src[1]
|
|
8100
|
+
if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
|
|
8101
|
+
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
|
|
8102
|
+
} else {
|
|
8103
|
+
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
|
|
8104
|
+
}
|
|
8105
|
+
}
|
|
8106
|
+
tensors[num_srcs] = dst;
|
|
8107
|
+
|
|
8108
|
+
vk_op_multi_add_push_constants pc;
|
|
8109
|
+
pc.ne20 = (uint32_t)dst->ne[0];
|
|
8110
|
+
pc.ne21 = (uint32_t)dst->ne[1];
|
|
8111
|
+
pc.ne22 = (uint32_t)dst->ne[2];
|
|
8112
|
+
pc.ne23 = (uint32_t)dst->ne[3];
|
|
8113
|
+
|
|
8114
|
+
for (uint32_t i = 0; i < num_tensors; ++i) {
|
|
8115
|
+
const ggml_tensor *t = tensors[i];
|
|
8116
|
+
pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
|
|
8117
|
+
pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
|
|
8118
|
+
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
|
|
8119
|
+
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
|
|
8120
|
+
}
|
|
8121
|
+
pc.rms_partials = ctx->do_add_rms_partials;
|
|
8122
|
+
|
|
8123
|
+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
|
|
8124
|
+
|
|
8125
|
+
if (pipeline == nullptr) {
|
|
8126
|
+
std::cerr << "ggml_vulkan: Error: Missing multi_add";
|
|
8127
|
+
GGML_ABORT("fatal error");
|
|
8128
|
+
}
|
|
8129
|
+
|
|
8130
|
+
if (dryrun) {
|
|
8131
|
+
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
8132
|
+
return;
|
|
8133
|
+
}
|
|
8134
|
+
|
|
8135
|
+
ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
|
|
8136
|
+
vk_buffer buf[MAX_PARAMETER_COUNT];
|
|
8137
|
+
size_t offset[MAX_PARAMETER_COUNT];
|
|
8138
|
+
bool uma[MAX_PARAMETER_COUNT];
|
|
8139
|
+
|
|
8140
|
+
for (uint32_t i = 0; i < num_tensors; ++i) {
|
|
8141
|
+
buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
|
|
8142
|
+
buf[i] = nullptr;
|
|
8143
|
+
offset[i] = 0;
|
|
8144
|
+
uma[i] = false;
|
|
8145
|
+
|
|
8146
|
+
if (ctx->device->uma) {
|
|
8147
|
+
ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
|
|
8148
|
+
uma[i] = buf[i] != nullptr;
|
|
8149
|
+
}
|
|
8150
|
+
if (!uma[i]) {
|
|
8151
|
+
buf[i] = buf_ctx[i]->dev_buffer;
|
|
8152
|
+
offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
|
|
8153
|
+
}
|
|
8154
|
+
GGML_ASSERT(buf[i] != nullptr);
|
|
8155
|
+
}
|
|
8156
|
+
// If any remaining descriptors are unused, just point them at src[0]
|
|
8157
|
+
for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
|
|
8158
|
+
buf[i] = buf[0];
|
|
8159
|
+
offset[i] = 0;
|
|
8160
|
+
}
|
|
8161
|
+
if (ctx->do_add_rms_partials) {
|
|
8162
|
+
buf[num_tensors] = ctx->prealloc_add_rms_partials;
|
|
8163
|
+
offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
|
|
8164
|
+
}
|
|
8165
|
+
|
|
8166
|
+
std::array<uint32_t, 3> elements;
|
|
8167
|
+
|
|
8168
|
+
uint32_t ne = ggml_nelements(dst);
|
|
8169
|
+
if (ne > 262144) {
|
|
8170
|
+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
8171
|
+
} else if (ne > 512) {
|
|
8172
|
+
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
|
8173
|
+
} else {
|
|
8174
|
+
elements = { ne, 1, 1 };
|
|
8175
|
+
}
|
|
8176
|
+
|
|
8177
|
+
static_assert(MAX_PARAMETER_COUNT == 12);
|
|
8178
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
8179
|
+
{
|
|
8180
|
+
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
|
|
8181
|
+
vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
|
|
8182
|
+
vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
|
|
8183
|
+
vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
|
|
8184
|
+
vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
|
|
8185
|
+
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
|
|
8186
|
+
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
|
|
8187
|
+
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
|
|
8188
|
+
vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
|
|
8189
|
+
vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
|
|
8190
|
+
vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
|
|
8191
|
+
vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
|
|
8192
|
+
}, pc, elements);
|
|
8193
|
+
}
|
|
8194
|
+
|
|
7475
8195
|
static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7476
8196
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7477
8197
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
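ggml_vk_multi_add above walks a chain of fused ADD nodes: node i+1 consumes node i's result through src[0] or src[1], and the other operand is appended to the tensor list, with the final node's output as the last entry. A toy model of that source gathering; the Node type and names are invented for the sketch only:

#include <cstdio>
#include <string>
#include <vector>

struct Node { std::string name; Node * src[2]; };

int main() {
    Node a{"a", {nullptr, nullptr}}, b{"b", {nullptr, nullptr}},
         c{"c", {nullptr, nullptr}}, d{"d", {nullptr, nullptr}};
    Node add0{"add0", {&a, &b}};
    Node add1{"add1", {&add0, &c}};
    Node add2{"add2", {&d, &add1}};   // the previous result may be either operand
    std::vector<Node *> chain = { &add0, &add1, &add2 };

    std::vector<Node *> tensors = { chain[0]->src[0], chain[0]->src[1] };
    for (size_t i = 0; i + 1 < chain.size(); ++i) {
        Node * next = chain[i + 1];
        // Keep whichever operand is not the previous node's result.
        tensors.push_back(next->src[0] == chain[i] ? next->src[1] : next->src[0]);
    }
    tensors.push_back(chain.back());  // last element is the destination

    for (Node * t : tensors) printf("%s ", t->name.c_str());
    printf("\n");                     // prints: a b c d add2
    return 0;
}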
@@ -7483,7 +8203,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7483
8203
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7484
8204
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7485
8205
|
0,
|
|
7486
|
-
0.0f, 0.0f,
|
|
8206
|
+
0.0f, 0.0f, ctx->do_add_rms_partials,
|
|
7487
8207
|
}, dryrun);
|
|
7488
8208
|
}
|
|
7489
8209
|
|
|
@@ -7532,6 +8252,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7532
8252
|
}, dryrun);
|
|
7533
8253
|
}
|
|
7534
8254
|
|
|
8255
|
+
static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
|
8256
|
+
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
8257
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
8258
|
+
const uint32_t src2_type_size = ggml_type_size(src2->type);
|
|
8259
|
+
|
|
8260
|
+
ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
|
|
8261
|
+
(uint32_t)dst->ne[0],
|
|
8262
|
+
(uint32_t)dst->ne[1],
|
|
8263
|
+
(uint32_t)src0->nb[1] / src0_type_size,
|
|
8264
|
+
(uint32_t)src0->nb[2] / src0_type_size,
|
|
8265
|
+
(uint32_t)src1->nb[1] / src1_type_size,
|
|
8266
|
+
(uint32_t)src2->nb[1] / src2_type_size,
|
|
8267
|
+
}, dryrun);
|
|
8268
|
+
}
|
|
8269
|
+
|
|
7535
8270
|
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
|
|
7536
8271
|
GGML_ASSERT(version == 6 || version == 7);
|
|
7537
8272
|
int num_srcs = version == 6 ? 6 : 7;
|
|
@@ -7556,8 +8291,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
7556
8291
|
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
|
7557
8292
|
}
|
|
7558
8293
|
|
|
7559
|
-
ggml_vk_sync_buffers(subctx);
|
|
7560
|
-
|
|
7561
8294
|
vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
|
|
7562
8295
|
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
|
|
7563
8296
|
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
|
|
@@ -7695,8 +8428,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
|
|
|
7695
8428
|
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
|
|
7696
8429
|
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
|
|
7697
8430
|
|
|
7698
|
-
ggml_vk_sync_buffers(subctx);
|
|
7699
|
-
|
|
7700
8431
|
vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
|
|
7701
8432
|
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
|
|
7702
8433
|
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
|
|
@@ -7763,6 +8494,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
|
|
|
7763
8494
|
);
|
|
7764
8495
|
}
|
|
7765
8496
|
|
|
8497
|
+
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
|
8498
|
+
const size_t n = ggml_nelements(dst->src[0]);
|
|
8499
|
+
|
|
8500
|
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
|
|
8501
|
+
}
|
|
8502
|
+
|
|
7766
8503
|
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7767
8504
|
int * op_params = (int *)dst->op_params;
|
|
7768
8505
|
|
|
@@ -7815,6 +8552,10 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7815
8552
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7816
8553
|
}
|
|
7817
8554
|
|
|
8555
|
+
static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
8556
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
8557
|
+
}
|
|
8558
|
+
|
|
7818
8559
|
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7819
8560
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7820
8561
|
}
|
|
@@ -7882,6 +8623,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

+    // Skip empty set_rows operations. For most ops the empty check at the start
+    // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
+    // with empty srcs.
+    if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
+        return;
+    }
+
     ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
@@ -7913,19 +8661,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
7913
8661
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
|
7914
8662
|
}
|
|
7915
8663
|
|
|
8664
|
+
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
|
|
8665
|
+
const uint32_t ne = (uint32_t)node->ne[0];
|
|
8666
|
+
const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
|
|
8667
|
+
const uint32_t num_partials = CEIL_DIV(ne, denom);
|
|
8668
|
+
return num_partials;
|
|
8669
|
+
}
|
|
8670
|
+
|
|
8671
|
+
static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
|
|
8672
|
+
const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
|
|
8673
|
+
const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
|
|
8674
|
+
return num_bytes;
|
|
8675
|
+
}
|
|
8676
|
+
|
|
7916
8677
|
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
|
7917
8678
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7918
8679
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7919
8680
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7920
8681
|
|
|
8682
|
+
uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
|
|
8683
|
+
|
|
7921
8684
|
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
|
|
7922
8685
|
(uint32_t)ggml_nelements(src0),
|
|
7923
8686
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7924
8687
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7925
8688
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7926
8689
|
0,
|
|
7927
|
-
op_params[0], 0.0f,
|
|
8690
|
+
op_params[0], 0.0f, (int32_t)param3,
|
|
7928
8691
|
}, dryrun);
|
|
8692
|
+
|
|
8693
|
+
if (ctx->do_add_rms_partials) {
|
|
8694
|
+
ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
|
|
8695
|
+
ctx->do_add_rms_partials = false;
|
|
8696
|
+
}
|
|
7929
8697
|
}
|
|
7930
8698
|
|
|
7931
8699
|
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
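The helpers above size the partials buffer per fused ADD+RMS_NORM: the number of partials is the row width divided (rounding up) by the workgroup denominator, and the byte size is rounded up to the device's binding alignment. A sketch of that arithmetic with assumed values for the denominator and alignment:

#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static uint32_t roundup_pow2(uint32_t v, uint32_t align) { return (v + align - 1) & ~(align - 1); }

int main() {
    const uint32_t ne0 = 4096;                 // row width of the tensor
    const uint32_t wg_denom = 128;             // elements covered per workgroup (assumed)
    const uint32_t binding_alignment = 256;    // partials_binding_alignment stand-in (assumed)

    uint32_t num_partials = ceil_div(ne0, wg_denom);
    uint32_t num_bytes = roundup_pow2(num_partials * (uint32_t)sizeof(uint32_t), binding_alignment);
    printf("%u partials -> %u bytes reserved\n", num_partials, num_bytes);
    return 0;
}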
@@ -7943,8 +8711,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
7943
8711
|
}
|
|
7944
8712
|
|
|
7945
8713
|
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
8714
|
+
const float * op_params_f = (const float *)dst->op_params;
|
|
8715
|
+
|
|
7946
8716
|
const bool swapped = (bool)dst->op_params[1];
|
|
7947
8717
|
const bool split = src1 != nullptr;
|
|
8718
|
+
const float alpha = op_params_f[2];
|
|
8719
|
+
const float limit = op_params_f[3];
|
|
7948
8720
|
|
|
7949
8721
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
7950
8722
|
|
|
@@ -7958,7 +8730,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7958
8730
|
|
|
7959
8731
|
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
|
|
7960
8732
|
|
|
7961
|
-
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
|
|
8733
|
+
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
|
|
8734
|
+
{
|
|
8735
|
+
(uint32_t)ggml_nelements(dst),
|
|
8736
|
+
(uint32_t)src0->ne[0],
|
|
8737
|
+
(uint32_t)dst->ne[0],
|
|
8738
|
+
mode,
|
|
8739
|
+
alpha,
|
|
8740
|
+
limit
|
|
8741
|
+
}, dryrun);
|
|
7962
8742
|
}
|
|
7963
8743
|
|
|
7964
8744
|
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
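ggml_vk_glu above reinterprets dst->op_params as floats to fetch the swiglu_oai alpha and limit at indices 2 and 3 (it casts the i32 array to const float *). The strict-aliasing-safe equivalent of that round trip, with invented values:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[4] = {};
    float alpha = 1.702f, limit = 7.0f;
    std::memcpy(&op_params[2], &alpha, sizeof(float));   // how the op would store them
    std::memcpy(&op_params[3], &limit, sizeof(float));

    float alpha_back, limit_back;
    std::memcpy(&alpha_back, &op_params[2], sizeof(float));
    std::memcpy(&limit_back, &op_params[3], sizeof(float));
    printf("alpha=%g limit=%g\n", alpha_back, limit_back);
    return 0;
}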
@@ -7966,7 +8746,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
7966
8746
|
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
|
7967
8747
|
}
|
|
7968
8748
|
|
|
7969
|
-
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
8749
|
+
static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
|
7970
8750
|
float * op_params = (float *)dst->op_params;
|
|
7971
8751
|
|
|
7972
8752
|
float scale = op_params[0];
|
|
@@ -7988,7 +8768,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7988
8768
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
7989
8769
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
7990
8770
|
|
|
7991
|
-
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1,
|
|
8771
|
+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
|
|
7992
8772
|
ncols,
|
|
7993
8773
|
src1 != nullptr ? nrows_y : (uint32_t)0,
|
|
7994
8774
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
@@ -7998,6 +8778,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7998
8778
|
m0, m1,
|
|
7999
8779
|
n_head_log2,
|
|
8000
8780
|
nrows_x,
|
|
8781
|
+
src2 != nullptr
|
|
8001
8782
|
}, dryrun);
|
|
8002
8783
|
}
|
|
8003
8784
|
|
|
@@ -8034,7 +8815,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
|
8034
8815
|
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
|
|
8035
8816
|
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
|
|
8036
8817
|
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
|
|
8037
|
-
sections[0], sections[1], sections[2], sections[3], backprop
|
|
8818
|
+
{ sections[0], sections[1], sections[2], sections[3] }, backprop
|
|
8038
8819
|
}, dryrun);
|
|
8039
8820
|
}
|
|
8040
8821
|
|
|
@@ -8043,30 +8824,30 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
8043
8824
|
|
|
8044
8825
|
uint32_t ncols = src0->ne[0];
|
|
8045
8826
|
|
|
8046
|
-
uint32_t ncols_pad = 1;
|
|
8047
|
-
while (ncols_pad < ncols) {
|
|
8048
|
-
ncols_pad *= 2;
|
|
8049
|
-
}
|
|
8050
|
-
|
|
8051
|
-
GGML_ASSERT(ncols_pad <= 1024);
|
|
8052
|
-
|
|
8053
8827
|
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
|
8054
8828
|
ncols,
|
|
8055
|
-
ncols_pad,
|
|
8056
8829
|
op_params[0],
|
|
8057
8830
|
}, dryrun);
|
|
8058
8831
|
}
|
|
8059
8832
|
|
|
8060
8833
|
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
8061
|
-
|
|
8834
|
+
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
|
|
8835
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
|
|
8062
8836
|
}
|
|
8063
8837
|
|
|
8064
8838
|
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
8065
|
-
|
|
8839
|
+
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
|
|
8840
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
|
|
8841
|
+
}
|
|
8842
|
+
|
|
8843
|
+
static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
8844
|
+
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
|
|
8845
|
+
p.weight = 1.0f / (float)src0->ne[0];
|
|
8846
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
|
|
8066
8847
|
}
|
|
8067
8848
|
|
|
8068
8849
|
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
8069
|
-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0],
|
|
8850
|
+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
|
|
8070
8851
|
}
|
|
8071
8852
|
|
|
8072
8853
|
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
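GGML_OP_MEAN above reuses the sum_rows pipeline with p.weight = 1/ne0, i.e. a weighted row sum. A CPU equivalent of that reuse on toy data:

#include <cstdio>
#include <vector>

int main() {
    const int ne0 = 4;
    std::vector<float> row = { 1.0f, 2.0f, 3.0f, 4.0f };

    float weight = 1.0f / (float)ne0;   // 1.0f for SUM_ROWS, 1/ne0 for MEAN
    float acc = 0.0f;
    for (float v : row) acc += v;
    printf("sum=%g mean=%g\n", acc, acc * weight);
    return 0;
}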
@@ -8178,13 +8959,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
8178
8959
|
|
|
8179
8960
|
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
|
8180
8961
|
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
8181
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
8962
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
8182
8963
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
8183
8964
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
8184
8965
|
|
|
8185
8966
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
8186
8967
|
|
|
8187
|
-
GGML_ASSERT(nb00 == sizeof(float));
|
|
8968
|
+
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
|
8188
8969
|
GGML_ASSERT(nb10 == sizeof(float));
|
|
8189
8970
|
GGML_ASSERT(nb0 == sizeof(float));
|
|
8190
8971
|
|
|
@@ -9190,6 +9971,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
9190
9971
|
}
|
|
9191
9972
|
ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
|
|
9192
9973
|
}
|
|
9974
|
+
if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
|
|
9975
|
+
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
|
|
9976
|
+
// Resize buffer
|
|
9977
|
+
if (ctx->prealloc_add_rms_partials != nullptr) {
|
|
9978
|
+
ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
|
|
9979
|
+
}
|
|
9980
|
+
ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
|
|
9981
|
+
}
|
|
9193
9982
|
}
|
|
9194
9983
|
|
|
9195
9984
|
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
@@ -9220,6 +10009,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9220
10009
|
return false;
|
|
9221
10010
|
case GGML_OP_UNARY:
|
|
9222
10011
|
switch (ggml_get_unary_op(node)) {
|
|
10012
|
+
case GGML_UNARY_OP_EXP:
|
|
9223
10013
|
case GGML_UNARY_OP_SILU:
|
|
9224
10014
|
case GGML_UNARY_OP_GELU:
|
|
9225
10015
|
case GGML_UNARY_OP_GELU_ERF:
|
|
@@ -9237,6 +10027,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9237
10027
|
case GGML_GLU_OP_GEGLU:
|
|
9238
10028
|
case GGML_GLU_OP_REGLU:
|
|
9239
10029
|
case GGML_GLU_OP_SWIGLU:
|
|
10030
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
9240
10031
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
9241
10032
|
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9242
10033
|
break;
|
|
@@ -9244,10 +10035,24 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9244
10035
|
return false;
|
|
9245
10036
|
}
|
|
9246
10037
|
break;
|
|
10038
|
+
case GGML_OP_ADD:
|
|
10039
|
+
{
|
|
10040
|
+
int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
|
|
10041
|
+
if (next_node_idx < cgraph->n_nodes &&
|
|
10042
|
+
cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
|
|
10043
|
+
cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
|
|
10044
|
+
ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
|
|
10045
|
+
ctx->device->add_rms_fusion) {
|
|
10046
|
+
if (dryrun) {
|
|
10047
|
+
ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
|
|
10048
|
+
}
|
|
10049
|
+
ctx->do_add_rms_partials = true;
|
|
10050
|
+
}
|
|
10051
|
+
} break;
|
|
9247
10052
|
case GGML_OP_REPEAT:
|
|
9248
10053
|
case GGML_OP_REPEAT_BACK:
|
|
9249
10054
|
case GGML_OP_GET_ROWS:
|
|
9250
|
-
case
|
|
10055
|
+
case GGML_OP_ADD_ID:
|
|
9251
10056
|
case GGML_OP_ACC:
|
|
9252
10057
|
case GGML_OP_SUB:
|
|
9253
10058
|
case GGML_OP_MUL:
|
|
@@ -9256,6 +10061,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9256
10061
|
case GGML_OP_UPSCALE:
|
|
9257
10062
|
case GGML_OP_SCALE:
|
|
9258
10063
|
case GGML_OP_SQR:
|
|
10064
|
+
case GGML_OP_SQRT:
|
|
9259
10065
|
case GGML_OP_SIN:
|
|
9260
10066
|
case GGML_OP_COS:
|
|
9261
10067
|
case GGML_OP_CLAMP:
|
|
@@ -9281,6 +10087,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9281
10087
|
case GGML_OP_ARGSORT:
|
|
9282
10088
|
case GGML_OP_SUM:
|
|
9283
10089
|
case GGML_OP_SUM_ROWS:
|
|
10090
|
+
case GGML_OP_MEAN:
|
|
9284
10091
|
case GGML_OP_ARGMAX:
|
|
9285
10092
|
case GGML_OP_COUNT_EQUAL:
|
|
9286
10093
|
case GGML_OP_IM2COL:
|
|
@@ -9294,11 +10101,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9294
10101
|
case GGML_OP_LEAKY_RELU:
|
|
9295
10102
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
9296
10103
|
case GGML_OP_OPT_STEP_ADAMW:
|
|
10104
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
9297
10105
|
break;
|
|
9298
10106
|
default:
|
|
9299
10107
|
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
|
9300
10108
|
GGML_ABORT("fatal error");
|
|
9301
|
-
return false;
|
|
9302
10109
|
}
|
|
9303
10110
|
|
|
9304
10111
|
vk_context compute_ctx;
|
|
@@ -9325,6 +10132,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9325
10132
|
case GGML_OP_UPSCALE:
|
|
9326
10133
|
case GGML_OP_SCALE:
|
|
9327
10134
|
case GGML_OP_SQR:
|
|
10135
|
+
case GGML_OP_SQRT:
|
|
9328
10136
|
case GGML_OP_SIN:
|
|
9329
10137
|
case GGML_OP_COS:
|
|
9330
10138
|
case GGML_OP_CLAMP:
|
|
@@ -9349,6 +10157,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9349
10157
|
case GGML_OP_ARGSORT:
|
|
9350
10158
|
case GGML_OP_SUM:
|
|
9351
10159
|
case GGML_OP_SUM_ROWS:
|
|
10160
|
+
case GGML_OP_MEAN:
|
|
9352
10161
|
case GGML_OP_ARGMAX:
|
|
9353
10162
|
case GGML_OP_COUNT_EQUAL:
|
|
9354
10163
|
case GGML_OP_IM2COL:
|
|
@@ -9358,11 +10167,15 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9358
10167
|
case GGML_OP_CONV_2D:
|
|
9359
10168
|
case GGML_OP_CONV_2D_DW:
|
|
9360
10169
|
case GGML_OP_LEAKY_RELU:
|
|
10170
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
9361
10171
|
{
|
|
9362
10172
|
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
|
9363
10173
|
// do the only thing needed for the dryrun.
|
|
9364
10174
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
|
|
9365
10175
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
10176
|
+
if (node->op == GGML_OP_RMS_NORM) {
|
|
10177
|
+
ctx->do_add_rms_partials = false;
|
|
10178
|
+
}
|
|
9366
10179
|
return false;
|
|
9367
10180
|
}
|
|
9368
10181
|
default:
|
|
@@ -9370,6 +10183,80 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9370
10183
|
}
|
|
9371
10184
|
}
|
|
9372
10185
|
|
|
10186
|
+
if (!dryrun) {
|
|
10187
|
+
// This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
|
|
10188
|
+
// to synchronize them. This handles most "normal" synchronization when computing the graph, and when
|
|
10189
|
+
// there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
|
|
10190
|
+
// outside of this logic. When a node uses one of the prealloc buffers for something like
|
|
10191
|
+
// dequantization or split_k, additional synchronization is needed between those passes.
|
|
10192
|
+
bool need_sync = false;
|
|
10193
|
+
|
|
10194
|
+
// Check whether "node" requires synchronization. The node requires synchronization if it
|
|
10195
|
+
// overlaps in memory with another unsynchronized node and at least one of them is a write.
|
|
10196
|
+
// Destination nodes are checked against both the written/read lists. Source nodes are only
|
|
10197
|
+
// checked against the written list. Two nodes overlap in memory if they come from the same
|
|
10198
|
+
// buffer and the tensor or view ranges overlap.
|
|
10199
|
+
auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
|
|
10200
|
+
if (unsynced_nodes.size() == 0) {
|
|
10201
|
+
return false;
|
|
10202
|
+
}
|
|
10203
|
+
auto n_base = vk_tensor_offset(node) + node->view_offs;
|
|
10204
|
+
auto n_size = ggml_nbytes(node);
|
|
10205
|
+
ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
|
|
10206
|
+
vk_buffer a_buf = a_buf_ctx->dev_buffer;
|
|
10207
|
+
for (auto &other : unsynced_nodes) {
|
|
10208
|
+
ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
|
|
10209
|
+
vk_buffer o_buf = o_buf_ctx->dev_buffer;
|
|
10210
|
+
if (a_buf == o_buf) {
|
|
10211
|
+
auto o_base = vk_tensor_offset(other) + other->view_offs;
|
|
10212
|
+
auto o_size = ggml_nbytes(other);
|
|
10213
|
+
|
|
10214
|
+
if ((o_base <= n_base && n_base < o_base + o_size) ||
|
|
10215
|
+
(n_base <= o_base && o_base < n_base + n_size)) {
|
|
10216
|
+
return true;
|
|
10217
|
+
}
|
|
10218
|
+
}
|
|
10219
|
+
}
|
|
10220
|
+
return false;
|
|
10221
|
+
};
|
|
10222
|
+
|
|
10223
|
+
// For all fused ops, check if the destination node or any of the source
|
|
10224
|
+
// nodes require synchronization.
|
|
10225
|
+
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
|
|
10226
|
+
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
|
|
10227
|
+
if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
|
|
10228
|
+
need_sync = true;
|
|
10229
|
+
break;
|
|
10230
|
+
}
|
|
10231
|
+
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
|
|
10232
|
+
if (!cur_node->src[j]) {
|
|
10233
|
+
continue;
|
|
10234
|
+
}
|
|
10235
|
+
if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
|
|
10236
|
+
need_sync = true;
|
|
10237
|
+
break;
|
|
10238
|
+
}
|
|
10239
|
+
}
|
|
10240
|
+
}
|
|
10241
|
+
if (need_sync) {
|
|
10242
|
+
ctx->unsynced_nodes_written.clear();
|
|
10243
|
+
ctx->unsynced_nodes_read.clear();
|
|
10244
|
+
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
10245
|
+
}
|
|
10246
|
+
// Add the last fused node and all fused source nodes to the unsynchronized list.
|
|
10247
|
+
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
|
|
10248
|
+
ctx->unsynced_nodes_written.push_back(last_node);
|
|
10249
|
+
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
|
10250
|
+
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
|
|
10251
|
+
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
|
|
10252
|
+
if (!cur_node->src[j]) {
|
|
10253
|
+
continue;
|
|
10254
|
+
}
|
|
10255
|
+
ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
|
|
10256
|
+
}
|
|
10257
|
+
}
|
|
10258
|
+
}
|
|
10259
|
+
|
|
9373
10260
|
switch (node->op) {
|
|
9374
10261
|
case GGML_OP_REPEAT:
|
|
9375
10262
|
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
|
|
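The graph-level synchronization added above treats each tensor as a [base, base + size) byte range inside its device buffer and inserts a barrier when a node's range overlaps an unsynchronized earlier access with at least one write. The overlap test itself, extracted into a standalone check:

#include <cstddef>
#include <cstdio>

static bool ranges_overlap(size_t a_base, size_t a_size, size_t b_base, size_t b_size) {
    return (a_base <= b_base && b_base < a_base + a_size) ||
           (b_base <= a_base && a_base < b_base + b_size);
}

int main() {
    printf("%d\n", ranges_overlap(0, 64, 32, 64));    // 1: overlapping ranges need a barrier
    printf("%d\n", ranges_overlap(0, 64, 64, 64));    // 0: adjacent but disjoint
    printf("%d\n", ranges_overlap(128, 16, 0, 256));  // 1: containment also counts
    return 0;
}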
@@ -9388,8 +10275,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9388
10275
|
|
|
9389
10276
|
break;
|
|
9390
10277
|
case GGML_OP_ADD:
|
|
9391
|
-
|
|
9392
|
-
|
|
10278
|
+
if (ctx->num_additional_fused_ops) {
|
|
10279
|
+
ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
|
|
10280
|
+
} else {
|
|
10281
|
+
ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
10282
|
+
}
|
|
9393
10283
|
break;
|
|
9394
10284
|
case GGML_OP_SUB:
|
|
9395
10285
|
ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -9402,6 +10292,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9402
10292
|
case GGML_OP_DIV:
|
|
9403
10293
|
ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9404
10294
|
|
|
10295
|
+
break;
|
|
10296
|
+
case GGML_OP_ADD_ID:
|
|
10297
|
+
ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
10298
|
+
|
|
9405
10299
|
break;
|
|
9406
10300
|
case GGML_OP_CONCAT:
|
|
9407
10301
|
ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -9418,6 +10312,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9418
10312
|
case GGML_OP_SQR:
|
|
9419
10313
|
ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
|
|
9420
10314
|
|
|
10315
|
+
break;
|
|
10316
|
+
case GGML_OP_SQRT:
|
|
10317
|
+
ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun);
|
|
10318
|
+
|
|
9421
10319
|
break;
|
|
9422
10320
|
case GGML_OP_SIN:
|
|
9423
10321
|
ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
|
|
@@ -9481,6 +10379,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9481
10379
|
break;
|
|
9482
10380
|
case GGML_OP_UNARY:
|
|
9483
10381
|
switch (ggml_get_unary_op(node)) {
|
|
10382
|
+
case GGML_UNARY_OP_EXP:
|
|
9484
10383
|
case GGML_UNARY_OP_SILU:
|
|
9485
10384
|
case GGML_UNARY_OP_GELU:
|
|
9486
10385
|
case GGML_UNARY_OP_GELU_ERF:
|
|
@@ -9499,6 +10398,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9499
10398
|
case GGML_GLU_OP_GEGLU:
|
|
9500
10399
|
case GGML_GLU_OP_REGLU:
|
|
9501
10400
|
case GGML_GLU_OP_SWIGLU:
|
|
10401
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
9502
10402
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
9503
10403
|
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9504
10404
|
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -9512,7 +10412,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9512
10412
|
|
|
9513
10413
|
break;
|
|
9514
10414
|
case GGML_OP_SOFT_MAX:
|
|
9515
|
-
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
10415
|
+
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
9516
10416
|
|
|
9517
10417
|
break;
|
|
9518
10418
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -9538,6 +10438,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9538
10438
|
case GGML_OP_SUM_ROWS:
|
|
9539
10439
|
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
|
|
9540
10440
|
|
|
10441
|
+
break;
|
|
10442
|
+
case GGML_OP_MEAN:
|
|
10443
|
+
ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
|
|
10444
|
+
|
|
9541
10445
|
break;
|
|
9542
10446
|
case GGML_OP_ARGMAX:
|
|
9543
10447
|
ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
|
|
@@ -9585,7 +10489,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9585
10489
|
break;
|
|
9586
10490
|
|
|
9587
10491
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
9588
|
-
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
|
|
10492
|
+
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
|
|
9589
10493
|
|
|
9590
10494
|
break;
|
|
9591
10495
|
|
|
@@ -9602,6 +10506,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
9602
10506
|
case GGML_OP_OPT_STEP_ADAMW:
|
|
9603
10507
|
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
|
9604
10508
|
|
|
10509
|
+
break;
|
|
10510
|
+
|
|
10511
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
10512
|
+
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
10513
|
+
|
|
9605
10514
|
break;
|
|
9606
10515
|
default:
|
|
9607
10516
|
return false;
|
|
@@ -9658,10 +10567,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
9658
10567
|
case GGML_OP_SUB:
|
|
9659
10568
|
case GGML_OP_MUL:
|
|
9660
10569
|
case GGML_OP_DIV:
|
|
10570
|
+
case GGML_OP_ADD_ID:
|
|
9661
10571
|
case GGML_OP_CONCAT:
|
|
9662
10572
|
case GGML_OP_UPSCALE:
|
|
9663
10573
|
case GGML_OP_SCALE:
|
|
9664
10574
|
case GGML_OP_SQR:
|
|
10575
|
+
case GGML_OP_SQRT:
|
|
9665
10576
|
case GGML_OP_SIN:
|
|
9666
10577
|
case GGML_OP_COS:
|
|
9667
10578
|
case GGML_OP_CLAMP:
|
|
@@ -9690,6 +10601,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
9690
10601
|
case GGML_OP_ARGSORT:
|
|
9691
10602
|
case GGML_OP_SUM:
|
|
9692
10603
|
case GGML_OP_SUM_ROWS:
|
|
10604
|
+
case GGML_OP_MEAN:
|
|
9693
10605
|
case GGML_OP_ARGMAX:
|
|
9694
10606
|
case GGML_OP_COUNT_EQUAL:
|
|
9695
10607
|
case GGML_OP_IM2COL:
|
|
@@ -9704,11 +10616,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
9704
10616
|
case GGML_OP_REPEAT:
|
|
9705
10617
|
case GGML_OP_REPEAT_BACK:
|
|
9706
10618
|
case GGML_OP_OPT_STEP_ADAMW:
|
|
10619
|
+
case GGML_OP_OPT_STEP_SGD:
|
|
9707
10620
|
buf = tensor->buffer;
|
|
9708
|
-
|
|
9709
10621
|
break;
|
|
9710
10622
|
case GGML_OP_UNARY:
|
|
9711
10623
|
switch (ggml_get_unary_op(tensor)) {
|
|
10624
|
+
case GGML_UNARY_OP_EXP:
|
|
9712
10625
|
case GGML_UNARY_OP_SILU:
|
|
9713
10626
|
case GGML_UNARY_OP_GELU:
|
|
9714
10627
|
case GGML_UNARY_OP_GELU_ERF:
|
|
@@ -9727,6 +10640,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|
|
9727
10640
|
case GGML_GLU_OP_GEGLU:
|
|
9728
10641
|
case GGML_GLU_OP_REGLU:
|
|
9729
10642
|
case GGML_GLU_OP_SWIGLU:
|
|
10643
|
+
case GGML_GLU_OP_SWIGLU_OAI:
|
|
9730
10644
|
case GGML_GLU_OP_GEGLU_ERF:
|
|
9731
10645
|
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9732
10646
|
buf = tensor->buffer;
|
|
@@ -9804,6 +10718,11 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
9804
10718
|
ggml_vk_pool_free(ctx, buffer);
|
|
9805
10719
|
}
|
|
9806
10720
|
ctx->gc.temp_buffers.clear();
|
|
10721
|
+
ctx->prealloc_y_last_pipeline_used = {};
|
|
10722
|
+
|
|
10723
|
+
ctx->unsynced_nodes_written.clear();
|
|
10724
|
+
ctx->unsynced_nodes_read.clear();
|
|
10725
|
+
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
|
|
9807
10726
|
|
|
9808
10727
|
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
|
9809
10728
|
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
|
@@ -9839,6 +10758,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
9839
10758
|
ggml_vk_destroy_buffer(ctx->prealloc_x);
|
|
9840
10759
|
ggml_vk_destroy_buffer(ctx->prealloc_y);
|
|
9841
10760
|
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
|
|
10761
|
+
ctx->prealloc_y_last_pipeline_used = nullptr;
|
|
9842
10762
|
|
|
9843
10763
|
for (auto& buffer : ctx->buffer_pool) {
|
|
9844
10764
|
ggml_vk_destroy_buffer(buffer);
|
|
@@ -10259,6 +11179,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }

+static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
+
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    if (first_node->op != GGML_OP_ADD) {
+        return 0;
+    }
+
+    if (!ctx->device->multi_add) {
+        return 0;
+    }
+
+    int32_t num_adds = 1;
+    while (node_idx + num_adds < cgraph->n_nodes &&
+           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
+           num_adds < MAX_FUSED_ADDS) {
+        num_adds++;
+    }
+
+    // The shader currently requires same shapes (but different strides are allowed),
+    // everything f32, and no misalignment
+    for (int32_t i = 0; i < num_adds; ++i) {
+        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
+        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
+            !ggml_are_same_shape(first_node, next_node->src[1]) ||
+            next_node->type != GGML_TYPE_F32 ||
+            next_node->src[0]->type != GGML_TYPE_F32 ||
+            next_node->src[1]->type != GGML_TYPE_F32 ||
+            get_misalign_bytes(ctx, next_node) ||
+            get_misalign_bytes(ctx, next_node->src[0]) ||
+            get_misalign_bytes(ctx, next_node->src[1])) {
+            num_adds = i;
+        }
+    }
+
+    // Verify we can fuse these
+    ggml_op adds[MAX_FUSED_ADDS];
+    for (int32_t i = 0; i < num_adds; ++i) {
+        adds[i] = GGML_OP_ADD;
+    }
+
+    // decrease num_adds if they can't all be fused
+    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
+        num_adds--;
+    }
+
+    // a single add is not "fused", so just return zero
+    if (num_adds == 1) {
+        return 0;
+    }
+    return num_adds;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10270,10 +11242,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }

+    ctx->prealloc_size_add_rms_partials = 0;
+    ctx->prealloc_size_add_rms_partials_offset = 0;
+    ctx->do_add_rms_partials = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (!ctx->device->disable_fusion
-
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10330,6 +11311,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }

+    ctx->prealloc_y_last_pipeline_used = nullptr;
+    ctx->prealloc_y_last_tensor_used = nullptr;
+
+    if (ctx->prealloc_size_add_rms_partials) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize partial sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
+        ggml_vk_sync_buffers(ctx, compute_ctx);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
@@ -10348,8 +11345,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }

-        if (!ctx->device->disable_fusion
-
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }

         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
@@ -10456,10 +11458,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
     ggml_vk_init(ctx, dev_num);

     ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid
-        /* .
-        /* .device
-        /* .context
+        /* .guid    = */ ggml_backend_vk_guid(),
+        /* .iface   = */ ggml_backend_vk_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
+        /* .context = */ ctx,
     };

     return vk_backend;
@@ -10556,6 +11558,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
@@ -10570,12 +11573,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 default:
                     return false;
             }
-            break;
         case GGML_OP_GLU:
             switch (ggml_get_glu_op(op)) {
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous(op->src[0]) &&
@@ -10585,7 +11588,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 default:
                     return false;
             }
-            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
@@ -10621,6 +11623,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ3_S:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_MXFP4:
                         break;
                     default:
                         return false;
@@ -10648,14 +11651,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 }

                 return true;
-            }
+            }
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                 auto device = ggml_vk_get_device(ctx->device);
                 bool coopmat2 = device->coopmat2;
-
-
+                uint32_t HSK = op->src[1]->ne[0];
+                uint32_t HSV = op->src[2]->ne[0];
+                if ((HSK % 8) != 0 || (HSV % 8) != 0) {
+                    return false;
+                }
+                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
                     return false;
                 }
                 if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10730,11 +11737,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ3_S:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_MXFP4:
                         return true;
                     default:
                         return false;
                 }
-            }
+            }
         case GGML_OP_SET_ROWS:
             {
                 switch (op->type) {
@@ -10751,7 +11759,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     default:
                         return false;
                 }
-            }
+            }
         case GGML_OP_CONT:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -10803,7 +11811,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return true;
             }
             return false;
-        }
+        }
         case GGML_OP_REPEAT:
             return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
         case GGML_OP_REPEAT_BACK:
@@ -10828,13 +11836,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+                   op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            return op->ne[0] <= max_argsort_cols;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
@@ -10844,9 +11861,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
-
+            return true;
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
+            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGMAX:
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
@@ -10855,8 +11874,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@@ -10865,14 +11882,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
                 const vk_device& device = ggml_vk_get_device(ctx->device);
-                bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
                 // Channel-contiguous format is not supported yet.
-                return (op->src[0]->type == GGML_TYPE_F32 &&
+                return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                         op->src[1]->type == GGML_TYPE_F32 &&
                         op->type == GGML_TYPE_F32 &&
                         ggml_is_contiguous(op->src[0]) &&
                         ggml_is_contiguous(op->src[1]) &&
-                        ggml_is_contiguous(op))
+                        ggml_is_contiguous(op));
             }
         default:
             return false;
@@ -11147,7 +12163,7 @@ size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
 static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
     ggml_tensor * tensor = cgraph->nodes[tensor_idx];
-    if (tensor->op == GGML_OP_TRANSPOSE) {
+    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
         return;
     }

@@ -11246,6 +12262,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
@@ -11264,12 +12283,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone =
+        tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
-        tensor_clone =
+        tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_SQRT) {
+        tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
         tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {
@@ -11340,6 +12361,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
+        case GGML_UNARY_OP_EXP:
+            tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
+            break;
         case GGML_UNARY_OP_SILU:
             tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
             break;
@@ -11371,6 +12395,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
         }
+        ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
+        ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -11378,8 +12404,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
-    } else if (tensor->op == GGML_OP_SET_ROWS) {
-        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONT) {
         tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {
@@ -11399,6 +12423,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_MEAN) {
+        tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_ARGMAX) {
         tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
@@ -11453,6 +12479,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
                                            src_clone[2], src_clone[3], src_clone[4]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
+                                         src_clone[2]);
+    } else if (tensor->op == GGML_OP_ADD_ID) {
+        tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -11487,14 +12519,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *

 static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
     ggml_tensor * tensor = cgraph->nodes[tensor_idx];
-    if (tensor->op == GGML_OP_TRANSPOSE) {
+    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
         return;
     }
-    bool fused_rms_norm_mul = false;
     if (ctx->num_additional_fused_ops == 1 &&
         tensor->op == GGML_OP_RMS_NORM &&
         cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
-        fused_rms_norm_mul = true;
         tensor = cgraph->nodes[tensor_idx + 1];
     }

@@ -11547,6 +12577,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->type == GGML_TYPE_F16) {
         correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
         result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
+    } else if (tensor->type == GGML_TYPE_BF16) {
+        correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
+        result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
     } else if (tensor->type == GGML_TYPE_I32) {
         correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
         result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);