@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -224,6 +224,21 @@ enum vk_device_architecture {
|
|
|
224
224
|
INTEL_XE2,
|
|
225
225
|
};
|
|
226
226
|
|
|
227
|
+
// HSK x HSV
|
|
228
|
+
enum FaHeadSizes {
|
|
229
|
+
FA_HEAD_SIZE_64,
|
|
230
|
+
FA_HEAD_SIZE_80,
|
|
231
|
+
FA_HEAD_SIZE_96,
|
|
232
|
+
FA_HEAD_SIZE_112,
|
|
233
|
+
FA_HEAD_SIZE_128,
|
|
234
|
+
FA_HEAD_SIZE_192,
|
|
235
|
+
FA_HEAD_SIZE_192_128,
|
|
236
|
+
FA_HEAD_SIZE_256,
|
|
237
|
+
FA_HEAD_SIZE_576_512,
|
|
238
|
+
FA_HEAD_SIZE_UNSUPPORTED,
|
|
239
|
+
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
|
|
240
|
+
};
|
|
241
|
+
|
|
227
242
|
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
|
228
243
|
vk::PhysicalDeviceProperties props = device.getProperties();
|
|
229
244
|
|
|
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
|
|
305
320
|
}
|
|
306
321
|
|
|
307
322
|
struct vk_device_struct {
|
|
308
|
-
std::
|
|
323
|
+
std::recursive_mutex mutex;
|
|
309
324
|
|
|
310
325
|
vk::PhysicalDevice physical_device;
|
|
311
326
|
vk::PhysicalDeviceProperties properties;
|
|
@@ -410,32 +425,42 @@ struct vk_device_struct {
|
|
|
410
425
|
vk_pipeline pipeline_div_norepeat[2][2][2];
|
|
411
426
|
|
|
412
427
|
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
|
413
|
-
vk_pipeline
|
|
428
|
+
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
|
414
429
|
vk_pipeline pipeline_scale_f32;
|
|
415
430
|
vk_pipeline pipeline_sqr_f32;
|
|
416
431
|
vk_pipeline pipeline_sin_f32;
|
|
417
432
|
vk_pipeline pipeline_cos_f32;
|
|
418
433
|
vk_pipeline pipeline_clamp_f32;
|
|
419
434
|
vk_pipeline pipeline_pad_f32;
|
|
435
|
+
vk_pipeline pipeline_roll_f32;
|
|
420
436
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
|
421
437
|
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
|
|
422
438
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
|
|
423
439
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
|
424
440
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
|
441
|
+
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
|
|
425
442
|
vk_pipeline pipeline_norm_f32;
|
|
426
443
|
vk_pipeline pipeline_group_norm_f32;
|
|
427
444
|
vk_pipeline pipeline_rms_norm_f32;
|
|
445
|
+
vk_pipeline pipeline_rms_norm_mul_f32;
|
|
428
446
|
vk_pipeline pipeline_rms_norm_back_f32;
|
|
429
447
|
vk_pipeline pipeline_l2_norm_f32;
|
|
430
448
|
|
|
431
449
|
// [src/dst 0=fp32,1=fp16]
|
|
432
450
|
vk_pipeline pipeline_gelu[2];
|
|
451
|
+
vk_pipeline pipeline_gelu_erf[2];
|
|
433
452
|
vk_pipeline pipeline_gelu_quick[2];
|
|
434
453
|
vk_pipeline pipeline_silu[2];
|
|
435
454
|
vk_pipeline pipeline_relu[2];
|
|
436
455
|
vk_pipeline pipeline_tanh[2];
|
|
437
456
|
vk_pipeline pipeline_sigmoid[2];
|
|
438
457
|
|
|
458
|
+
vk_pipeline pipeline_geglu[2];
|
|
459
|
+
vk_pipeline pipeline_reglu[2];
|
|
460
|
+
vk_pipeline pipeline_swiglu[2];
|
|
461
|
+
vk_pipeline pipeline_geglu_erf[2];
|
|
462
|
+
vk_pipeline pipeline_geglu_quick[2];
|
|
463
|
+
|
|
439
464
|
vk_pipeline pipeline_leaky_relu_f32;
|
|
440
465
|
vk_pipeline pipeline_silu_back_f32;
|
|
441
466
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
@@ -461,26 +486,11 @@ struct vk_device_struct {
|
|
|
461
486
|
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
462
487
|
|
|
463
488
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
464
|
-
vk_pipeline
|
|
465
|
-
|
|
466
|
-
vk_pipeline
|
|
467
|
-
|
|
468
|
-
vk_pipeline
|
|
469
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
470
|
-
|
|
471
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
472
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
473
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
474
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
475
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
476
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
477
|
-
|
|
478
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
479
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
480
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
481
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
|
482
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
483
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
489
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
490
|
+
|
|
491
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
492
|
+
|
|
493
|
+
vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
484
494
|
|
|
485
495
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
486
496
|
|
|
@@ -493,6 +503,8 @@ struct vk_device_struct {
|
|
|
493
503
|
|
|
494
504
|
ggml_backend_buffer_type buffer_type;
|
|
495
505
|
|
|
506
|
+
bool disable_fusion;
|
|
507
|
+
|
|
496
508
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
497
509
|
std::unique_ptr<vk_memory_logger> memory_logger;
|
|
498
510
|
#endif
|
|
@@ -627,6 +639,8 @@ struct vk_flash_attn_push_constants {
|
|
|
627
639
|
uint32_t nev2;
|
|
628
640
|
uint32_t nev3;
|
|
629
641
|
uint32_t nem1;
|
|
642
|
+
uint32_t nem2;
|
|
643
|
+
uint32_t nem3;
|
|
630
644
|
|
|
631
645
|
uint32_t nb01;
|
|
632
646
|
uint32_t nb02;
|
|
@@ -637,14 +651,12 @@ struct vk_flash_attn_push_constants {
|
|
|
637
651
|
uint32_t nb21;
|
|
638
652
|
uint32_t nb22;
|
|
639
653
|
uint32_t nb23;
|
|
640
|
-
uint32_t nb31;
|
|
641
654
|
|
|
642
655
|
float scale;
|
|
643
656
|
float max_bias;
|
|
644
657
|
float logit_softcap;
|
|
645
658
|
|
|
646
|
-
uint32_t
|
|
647
|
-
uint32_t n_head_log2;
|
|
659
|
+
uint32_t mask_n_head_log2;
|
|
648
660
|
float m0;
|
|
649
661
|
float m1;
|
|
650
662
|
|
|
@@ -652,6 +664,7 @@ struct vk_flash_attn_push_constants {
|
|
|
652
664
|
uint32_t split_kv;
|
|
653
665
|
uint32_t k_num;
|
|
654
666
|
};
|
|
667
|
+
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
|
|
655
668
|
|
|
656
669
|
struct vk_op_push_constants {
|
|
657
670
|
uint32_t KX;
|
|
@@ -660,6 +673,13 @@ struct vk_op_push_constants {
|
|
|
660
673
|
float param2;
|
|
661
674
|
};
|
|
662
675
|
|
|
676
|
+
struct vk_op_glu_push_constants {
|
|
677
|
+
uint32_t N;
|
|
678
|
+
uint32_t ne00;
|
|
679
|
+
uint32_t ne20;
|
|
680
|
+
uint32_t mode; // 0: default, 1: swapped, 2: split
|
|
681
|
+
};
|
|
682
|
+
|
|
663
683
|
struct vk_op_unary_push_constants {
|
|
664
684
|
uint32_t ne;
|
|
665
685
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
@@ -675,6 +695,37 @@ struct vk_op_unary_push_constants {
|
|
|
675
695
|
};
|
|
676
696
|
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
|
677
697
|
|
|
698
|
+
// Fills a vk_op_unary_push_constants block from a source/destination tensor pair.
//
//   src0 - tensor whose extents/strides populate the ne0x / nb0x fields
//   dst  - tensor whose extents/strides populate the ne1x / nb1x fields
//   ne   - element count to store in the block; when 0 it is taken from dst,
//          and src0 and dst are then required to have the same element count
//
// Strides are stored in units of elements (byte stride divided by the type size).
// The element count must fit in 32 bits.
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
    if (ne == 0) {
        ne = ggml_nelements(dst);
    }
    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());

    vk_op_unary_push_constants pc{};
    pc.ne = (uint32_t)ne;

    // source tensor: extents, then strides converted from bytes to elements
    const size_t src_elem_bytes = ggml_type_size(src0->type);
    pc.ne00 = (uint32_t)src0->ne[0];
    pc.ne01 = (uint32_t)src0->ne[1];
    pc.ne02 = (uint32_t)src0->ne[2];
    pc.ne03 = (uint32_t)src0->ne[3];
    pc.nb00 = (uint32_t)(src0->nb[0] / src_elem_bytes);
    pc.nb01 = (uint32_t)(src0->nb[1] / src_elem_bytes);
    pc.nb02 = (uint32_t)(src0->nb[2] / src_elem_bytes);
    pc.nb03 = (uint32_t)(src0->nb[3] / src_elem_bytes);

    // destination tensor: same layout, using its own element size
    const size_t dst_elem_bytes = ggml_type_size(dst->type);
    pc.ne10 = (uint32_t)dst->ne[0];
    pc.ne11 = (uint32_t)dst->ne[1];
    pc.ne12 = (uint32_t)dst->ne[2];
    pc.ne13 = (uint32_t)dst->ne[3];
    pc.nb10 = (uint32_t)(dst->nb[0] / dst_elem_bytes);
    pc.nb11 = (uint32_t)(dst->nb[1] / dst_elem_bytes);
    pc.nb12 = (uint32_t)(dst->nb[2] / dst_elem_bytes);
    pc.nb13 = (uint32_t)(dst->nb[3] / dst_elem_bytes);

    // fastdiv values and offsets are initialized later in ggml_vk_op
    return pc;
}
|
|
728
|
+
|
|
678
729
|
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
|
679
730
|
// Precompute mp (m' in the paper) and L such that division
|
|
680
731
|
// can be computed using a multiply (high 32b of 64b result)
|
|
@@ -743,6 +794,14 @@ struct vk_op_rope_push_constants {
|
|
|
743
794
|
struct vk_op_soft_max_push_constants {
|
|
744
795
|
uint32_t KX;
|
|
745
796
|
uint32_t KY;
|
|
797
|
+
uint32_t ne00;
|
|
798
|
+
uint32_t ne01;
|
|
799
|
+
uint32_t ne02;
|
|
800
|
+
uint32_t ne12;
|
|
801
|
+
uint32_t ne13;
|
|
802
|
+
uint32_t nb11;
|
|
803
|
+
uint32_t nb12;
|
|
804
|
+
uint32_t nb13;
|
|
746
805
|
float scale;
|
|
747
806
|
float max_bias;
|
|
748
807
|
float m0;
|
|
@@ -836,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
|
|
|
836
895
|
|
|
837
896
|
struct vk_op_upscale_push_constants {
|
|
838
897
|
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
|
898
|
+
uint32_t ne00; uint32_t ne01;
|
|
839
899
|
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
840
900
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
|
841
901
|
float sf0; float sf1; float sf2; float sf3;
|
|
@@ -978,6 +1038,10 @@ struct ggml_backend_vk_context {
|
|
|
978
1038
|
|
|
979
1039
|
vk_command_pool compute_cmd_pool;
|
|
980
1040
|
vk_command_pool transfer_cmd_pool;
|
|
1041
|
+
|
|
1042
|
+
// number of additional consecutive nodes that are being fused with the
|
|
1043
|
+
// node currently being processed
|
|
1044
|
+
int num_additional_fused_ops {};
|
|
981
1045
|
};
|
|
982
1046
|
|
|
983
1047
|
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
|
@@ -1041,6 +1105,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
|
|
|
1041
1105
|
struct vk_instance_t {
|
|
1042
1106
|
vk::Instance instance;
|
|
1043
1107
|
|
|
1108
|
+
bool debug_utils_support = false; // VK_EXT_debug_utils enabled
|
|
1109
|
+
PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
|
|
1110
|
+
PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
|
|
1111
|
+
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
|
|
1112
|
+
PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
|
|
1113
|
+
PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
|
|
1114
|
+
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
|
|
1115
|
+
|
|
1044
1116
|
std::vector<size_t> device_indices;
|
|
1045
1117
|
vk_device devices[GGML_VK_MAX_DEVICES];
|
|
1046
1118
|
};
|
|
@@ -1055,8 +1127,8 @@ static size_t vk_skip_checks;
|
|
|
1055
1127
|
static size_t vk_output_tensor;
|
|
1056
1128
|
|
|
1057
1129
|
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
|
|
1058
|
-
static void ggml_vk_check_results_0(
|
|
1059
|
-
static void ggml_vk_check_results_1(
|
|
1130
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1131
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1060
1132
|
#endif
|
|
1061
1133
|
|
|
1062
1134
|
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
|
@@ -1180,8 +1252,16 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
1180
1252
|
}
|
|
1181
1253
|
pipeline->compiled = true;
|
|
1182
1254
|
|
|
1255
|
+
if (vk_instance.debug_utils_support) {
|
|
1256
|
+
vk::DebugUtilsObjectNameInfoEXT duoni;
|
|
1257
|
+
duoni.objectType = vk::ObjectType::ePipeline;
|
|
1258
|
+
duoni.pObjectName = pipeline->name.c_str();
|
|
1259
|
+
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
|
|
1260
|
+
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
|
|
1261
|
+
}
|
|
1262
|
+
|
|
1183
1263
|
{
|
|
1184
|
-
std::lock_guard<std::
|
|
1264
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1185
1265
|
device->pipelines.insert({ pipeline->name, pipeline });
|
|
1186
1266
|
}
|
|
1187
1267
|
|
|
@@ -1395,7 +1475,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
|
|
|
1395
1475
|
|
|
1396
1476
|
static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
|
|
1397
1477
|
VK_LOG_DEBUG("ggml_vk_create_queue()");
|
|
1398
|
-
std::lock_guard<std::
|
|
1478
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1399
1479
|
|
|
1400
1480
|
q.queue_family_index = queue_family_index;
|
|
1401
1481
|
q.transfer_only = transfer_only;
|
|
@@ -1657,10 +1737,46 @@ enum FaCodePath {
|
|
|
1657
1737
|
FA_COOPMAT2,
|
|
1658
1738
|
};
|
|
1659
1739
|
|
|
1740
|
+
// Maps a (K head size, V head size) pair to the FaHeadSizes slot used to index
// the flash-attention pipeline tables.
//
// Only K sizes 192 and 576 may be paired with a different V size
// (192/128 and 576/512); every other combination requires hsk == hsv.
// Anything not listed yields FA_HEAD_SIZE_UNSUPPORTED.
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
    // K sizes that have (or only exist as) a mixed K/V variant
    if (hsk == 192) {
        if (hsv == 192) {
            return FA_HEAD_SIZE_192;
        }
        if (hsv == 128) {
            return FA_HEAD_SIZE_192_128;
        }
        return FA_HEAD_SIZE_UNSUPPORTED;
    }
    if (hsk == 576) {
        return hsv == 512 ? FA_HEAD_SIZE_576_512 : FA_HEAD_SIZE_UNSUPPORTED;
    }

    // all remaining sizes are only supported with matching K and V
    if (hsk != hsv) {
        return FA_HEAD_SIZE_UNSUPPORTED;
    }

    switch (hsk) {
        case 64:  return FA_HEAD_SIZE_64;
        case 80:  return FA_HEAD_SIZE_80;
        case 96:  return FA_HEAD_SIZE_96;
        case 112: return FA_HEAD_SIZE_112;
        case 128: return FA_HEAD_SIZE_128;
        case 256: return FA_HEAD_SIZE_256;
        default:  return FA_HEAD_SIZE_UNSUPPORTED;
    }
}
|
|
1768
|
+
|
|
1660
1769
|
// number of rows/cols for flash attention shader
|
|
1661
1770
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1662
1771
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
1663
|
-
|
|
1772
|
+
|
|
1773
|
+
// Row count used by the scalar flash-attention path in its "large rows"
// configuration: 2 rows once the V head size reaches 512, otherwise 8.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
    return hsv >= 512 ? 2 : 8;
}
|
|
1664
1780
|
|
|
1665
1781
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
1666
1782
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
@@ -1677,14 +1793,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|
|
1677
1793
|
}
|
|
1678
1794
|
}
|
|
1679
1795
|
|
|
1680
|
-
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t
|
|
1796
|
+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1681
1797
|
GGML_UNUSED(clamp);
|
|
1798
|
+
GGML_UNUSED(hsv);
|
|
1682
1799
|
|
|
1683
1800
|
if (path == FA_SCALAR) {
|
|
1684
1801
|
if (small_rows) {
|
|
1685
1802
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
1686
1803
|
} else {
|
|
1687
|
-
return {
|
|
1804
|
+
return {get_fa_scalar_num_large_rows(hsv), 32};
|
|
1688
1805
|
}
|
|
1689
1806
|
}
|
|
1690
1807
|
|
|
@@ -1702,8 +1819,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
|
|
|
1702
1819
|
}
|
|
1703
1820
|
|
|
1704
1821
|
// small cols to reduce register count
|
|
1705
|
-
if (ggml_is_quantized(type) ||
|
|
1706
|
-
|
|
1822
|
+
if (ggml_is_quantized(type) || hsk >= 256) {
|
|
1823
|
+
if (hsk >= 512) {
|
|
1824
|
+
return {32, 32};
|
|
1825
|
+
} else {
|
|
1826
|
+
return {64, 32};
|
|
1827
|
+
}
|
|
1707
1828
|
}
|
|
1708
1829
|
return {64, 64};
|
|
1709
1830
|
}
|
|
@@ -1745,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
1745
1866
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
1746
1867
|
|
|
1747
1868
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
1748
|
-
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
|
1869
|
+
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
|
|
1749
1870
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
1750
1871
|
|
|
1751
1872
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
@@ -1870,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1870
1991
|
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
|
1871
1992
|
|
|
1872
1993
|
// spec constants and tile sizes for quant matmul_id
|
|
1873
|
-
l_warptile_mmqid = { 256, 128,
|
|
1994
|
+
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
|
1874
1995
|
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1875
1996
|
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1876
|
-
l_mmqid_wg_denoms = { 128,
|
|
1997
|
+
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
1877
1998
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1878
1999
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1879
2000
|
|
|
@@ -1995,19 +2116,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1995
2116
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
1996
2117
|
};
|
|
1997
2118
|
|
|
1998
|
-
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t
|
|
1999
|
-
return {fa_rows_cols(path,
|
|
2119
|
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
2120
|
+
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
|
2000
2121
|
};
|
|
2001
2122
|
|
|
2002
|
-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t
|
|
2123
|
+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
2003
2124
|
// For large number of rows, 128 invocations seems to work best.
|
|
2004
2125
|
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
2005
2126
|
// can't use 256 for D==80.
|
|
2006
2127
|
// For scalar, use 128 (arbitrary)
|
|
2128
|
+
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
|
2129
|
+
const uint32_t D = (hsk|hsv);
|
|
2007
2130
|
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
2008
2131
|
? scalar_flash_attention_workgroup_size
|
|
2009
2132
|
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
2010
|
-
auto rows_cols = fa_rows_cols(path,
|
|
2133
|
+
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
|
|
2011
2134
|
|
|
2012
2135
|
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
2013
2136
|
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
@@ -2016,26 +2139,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2016
2139
|
|
|
2017
2140
|
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
2018
2141
|
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
2019
|
-
return {wg_size, rows_cols[0], rows_cols[1],
|
|
2142
|
+
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
|
2020
2143
|
};
|
|
2021
2144
|
|
|
2022
|
-
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2023
|
-
ggml_vk_create_pipeline(device, device->
|
|
2024
|
-
ggml_vk_create_pipeline(device, device->
|
|
2025
|
-
ggml_vk_create_pipeline(device, device->
|
|
2026
|
-
ggml_vk_create_pipeline(device, device->
|
|
2027
|
-
ggml_vk_create_pipeline(device, device->
|
|
2028
|
-
ggml_vk_create_pipeline(device, device->
|
|
2029
|
-
ggml_vk_create_pipeline(device, device->
|
|
2030
|
-
ggml_vk_create_pipeline(device, device->
|
|
2145
|
+
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
|
|
2146
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2147
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2148
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2149
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2150
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2151
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2152
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2153
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2031
2154
|
|
|
2032
2155
|
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
2033
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
|
2034
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
|
2035
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
|
2036
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
|
2037
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
|
2038
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2156
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
|
|
2157
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
|
|
2158
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
|
|
2159
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
|
|
2160
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
|
|
2161
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
|
|
2162
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
|
|
2163
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
|
|
2164
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
|
|
2039
2165
|
|
|
2040
2166
|
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
2041
2167
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
@@ -2625,7 +2751,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2625
2751
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2626
2752
|
|
|
2627
2753
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
2628
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2,
|
|
2754
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
|
2629
2755
|
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
|
2630
2756
|
|
|
2631
2757
|
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
@@ -2639,7 +2765,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2639
2765
|
|
|
2640
2766
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2641
2767
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2642
|
-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main",
|
|
2768
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
|
|
2769
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
|
|
2643
2770
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2644
2771
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2645
2772
|
|
|
@@ -2656,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2656
2783
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2657
2784
|
|
|
2658
2785
|
if (device->float_controls_rte_fp16) {
|
|
2659
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2660
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2661
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2662
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2663
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2664
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2786
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2787
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2788
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2789
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2790
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2791
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2792
|
+
} else {
|
|
2793
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2794
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2795
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2796
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2797
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2798
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2799
|
+
}
|
|
2800
|
+
|
|
2801
|
+
if (device->float_controls_rte_fp16) {
|
|
2802
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2803
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2804
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2805
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2806
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2807
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2808
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2809
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2810
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2665
2811
|
} else {
|
|
2666
|
-
ggml_vk_create_pipeline(device, device->
|
|
2667
|
-
ggml_vk_create_pipeline(device, device->
|
|
2668
|
-
ggml_vk_create_pipeline(device, device->
|
|
2669
|
-
ggml_vk_create_pipeline(device, device->
|
|
2670
|
-
ggml_vk_create_pipeline(device, device->
|
|
2671
|
-
ggml_vk_create_pipeline(device, device->
|
|
2812
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2813
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2814
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2815
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2816
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2817
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2818
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2819
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2820
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2672
2821
|
}
|
|
2673
2822
|
|
|
2674
2823
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
@@ -2708,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2708
2857
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2709
2858
|
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2710
2859
|
|
|
2711
|
-
ggml_vk_create_pipeline(device, device->
|
|
2860
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
|
2861
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
|
2862
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
|
|
2712
2863
|
|
|
2713
2864
|
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2714
2865
|
|
|
@@ -2720,6 +2871,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2720
2871
|
|
|
2721
2872
|
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2722
2873
|
|
|
2874
|
+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2875
|
+
|
|
2723
2876
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2724
2877
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2725
2878
|
|
|
@@ -2728,6 +2881,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2728
2881
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2729
2882
|
|
|
2730
2883
|
CREATE_UNARY(gelu)
|
|
2884
|
+
CREATE_UNARY(gelu_erf)
|
|
2731
2885
|
CREATE_UNARY(gelu_quick)
|
|
2732
2886
|
CREATE_UNARY(silu)
|
|
2733
2887
|
CREATE_UNARY(relu)
|
|
@@ -2735,6 +2889,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2735
2889
|
CREATE_UNARY(sigmoid)
|
|
2736
2890
|
#undef CREATE_UNARY
|
|
2737
2891
|
|
|
2892
|
+
#define CREATE_GLU(name) \
|
|
2893
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2894
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
|
2895
|
+
|
|
2896
|
+
CREATE_GLU(geglu)
|
|
2897
|
+
CREATE_GLU(reglu)
|
|
2898
|
+
CREATE_GLU(swiglu)
|
|
2899
|
+
CREATE_GLU(geglu_erf)
|
|
2900
|
+
CREATE_GLU(geglu_quick)
|
|
2901
|
+
#undef CREATE_GLU
|
|
2902
|
+
|
|
2738
2903
|
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2739
2904
|
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2740
2905
|
|
|
@@ -3415,6 +3580,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3415
3580
|
|
|
3416
3581
|
device->idx = idx;
|
|
3417
3582
|
|
|
3583
|
+
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
|
3584
|
+
|
|
3418
3585
|
return device;
|
|
3419
3586
|
}
|
|
3420
3587
|
|
|
@@ -3561,6 +3728,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3561
3728
|
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
|
3562
3729
|
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
|
3563
3730
|
|
|
3731
|
+
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
|
3732
|
+
|
|
3564
3733
|
static void ggml_vk_instance_init() {
|
|
3565
3734
|
if (vk_instance_initialized) {
|
|
3566
3735
|
return;
|
|
@@ -3581,7 +3750,7 @@ static void ggml_vk_instance_init() {
|
|
|
3581
3750
|
#ifdef __APPLE__
|
|
3582
3751
|
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
|
3583
3752
|
#endif
|
|
3584
|
-
|
|
3753
|
+
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
|
|
3585
3754
|
std::vector<const char*> layers;
|
|
3586
3755
|
|
|
3587
3756
|
if (validation_ext) {
|
|
@@ -3596,6 +3765,9 @@ static void ggml_vk_instance_init() {
|
|
|
3596
3765
|
extensions.push_back("VK_KHR_portability_enumeration");
|
|
3597
3766
|
}
|
|
3598
3767
|
#endif
|
|
3768
|
+
if (debug_utils_ext) {
|
|
3769
|
+
extensions.push_back("VK_EXT_debug_utils");
|
|
3770
|
+
}
|
|
3599
3771
|
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
|
|
3600
3772
|
#ifdef __APPLE__
|
|
3601
3773
|
if (portability_enumeration_ext) {
|
|
@@ -3619,6 +3791,17 @@ static void ggml_vk_instance_init() {
|
|
|
3619
3791
|
vk_instance.instance = vk::createInstance(instance_create_info);
|
|
3620
3792
|
vk_instance_initialized = true;
|
|
3621
3793
|
|
|
3794
|
+
if (debug_utils_ext) {
|
|
3795
|
+
vk_instance.debug_utils_support = true;
|
|
3796
|
+
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
|
|
3797
|
+
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
|
|
3798
|
+
vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
|
|
3799
|
+
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
|
|
3800
|
+
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
|
|
3801
|
+
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
|
|
3802
|
+
|
|
3803
|
+
}
|
|
3804
|
+
|
|
3622
3805
|
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
|
3623
3806
|
|
|
3624
3807
|
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
|
@@ -4091,6 +4274,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
|
|
4091
4274
|
return nullptr;
|
|
4092
4275
|
}
|
|
4093
4276
|
|
|
4277
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4094
4278
|
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
|
|
4095
4279
|
|
|
4096
4280
|
return buf->ptr;
|
|
@@ -4101,6 +4285,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4101
4285
|
return;
|
|
4102
4286
|
}
|
|
4103
4287
|
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
|
|
4288
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4289
|
+
|
|
4104
4290
|
vk_buffer buf;
|
|
4105
4291
|
size_t index;
|
|
4106
4292
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4123,6 +4309,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4123
4309
|
}
|
|
4124
4310
|
|
|
4125
4311
|
static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
|
|
4312
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4126
4313
|
buf = nullptr;
|
|
4127
4314
|
buf_offset = 0;
|
|
4128
4315
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4424,7 +4611,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
|
|
|
4424
4611
|
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
|
|
4425
4612
|
}
|
|
4426
4613
|
} else {
|
|
4427
|
-
std::lock_guard<std::
|
|
4614
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4428
4615
|
|
|
4429
4616
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4430
4617
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
@@ -4515,7 +4702,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
|
|
|
4515
4702
|
|
|
4516
4703
|
memcpy(dst, (uint8_t *) src->ptr + offset, size);
|
|
4517
4704
|
} else {
|
|
4518
|
-
std::lock_guard<std::
|
|
4705
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4519
4706
|
|
|
4520
4707
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
4521
4708
|
ggml_vk_ctx_begin(src->device, subctx);
|
|
@@ -4545,7 +4732,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
|
|
|
4545
4732
|
|
|
4546
4733
|
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
|
|
4547
4734
|
if (src->device == dst->device) {
|
|
4548
|
-
std::lock_guard<std::
|
|
4735
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4549
4736
|
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
|
|
4550
4737
|
// Copy within the device
|
|
4551
4738
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
@@ -4580,7 +4767,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
|
|
|
4580
4767
|
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
|
|
4581
4768
|
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
|
|
4582
4769
|
|
|
4583
|
-
std::lock_guard<std::
|
|
4770
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4584
4771
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4585
4772
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
4586
4773
|
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
|
|
@@ -4807,9 +4994,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
4807
4994
|
// type size must be exactly 2 or 4.
|
|
4808
4995
|
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
|
|
4809
4996
|
if ((ggml_type_size(src->type) % 4) == 0) {
|
|
4810
|
-
|
|
4997
|
+
if (contig) {
|
|
4998
|
+
return ctx->device->pipeline_contig_cpy_f32_f32;
|
|
4999
|
+
} else {
|
|
5000
|
+
return ctx->device->pipeline_cpy_f32_f32;
|
|
5001
|
+
}
|
|
4811
5002
|
} else {
|
|
4812
|
-
|
|
5003
|
+
if (contig) {
|
|
5004
|
+
return ctx->device->pipeline_contig_cpy_f16_f16;
|
|
5005
|
+
} else {
|
|
5006
|
+
return ctx->device->pipeline_cpy_f16_f16;
|
|
5007
|
+
}
|
|
4813
5008
|
}
|
|
4814
5009
|
}
|
|
4815
5010
|
|
|
@@ -4870,7 +5065,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4870
5065
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
4871
5066
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
4872
5067
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
4873
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5068
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
4874
5069
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
4875
5070
|
|
|
4876
5071
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5098,7 +5293,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
5098
5293
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
5099
5294
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5100
5295
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
|
|
5101
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5296
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5102
5297
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5103
5298
|
|
|
5104
5299
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5699,7 +5894,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
5699
5894
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
|
5700
5895
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5701
5896
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
5702
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5897
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5703
5898
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5704
5899
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
|
5705
5900
|
|
|
@@ -5893,14 +6088,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5893
6088
|
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
|
5894
6089
|
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
|
|
5895
6090
|
} else {
|
|
5896
|
-
|
|
6091
|
+
// Split based on number of ids, to fit in shared memory
|
|
6092
|
+
const uint32_t nei0 = (uint32_t)src2->ne[0];
|
|
6093
|
+
const uint32_t nei1 = (uint32_t)src2->ne[1];
|
|
6094
|
+
|
|
6095
|
+
GGML_ASSERT(nei0 <= 4096);
|
|
6096
|
+
const uint32_t split_size = std::min(nei1, 4096u / nei0);
|
|
6097
|
+
|
|
6098
|
+
ggml_tensor src1_copy = *src1;
|
|
6099
|
+
ggml_tensor src2_copy = *src2;
|
|
6100
|
+
ggml_tensor dst_copy = *dst;
|
|
6101
|
+
|
|
6102
|
+
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
|
|
6103
|
+
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
|
|
6104
|
+
|
|
6105
|
+
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
|
|
6106
|
+
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
|
|
6107
|
+
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
|
|
6108
|
+
|
|
6109
|
+
src1_copy.ne[2] = n_tokens;
|
|
6110
|
+
src2_copy.ne[1] = n_tokens;
|
|
6111
|
+
dst_copy.ne[2] = n_tokens;
|
|
6112
|
+
|
|
6113
|
+
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
|
|
6114
|
+
}
|
|
5897
6115
|
}
|
|
5898
6116
|
}
|
|
5899
6117
|
|
|
5900
|
-
static bool
|
|
6118
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
|
5901
6119
|
// Needs to be kept up to date on shader changes
|
|
6120
|
+
GGML_UNUSED(hsv);
|
|
5902
6121
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
5903
|
-
const uint32_t Br =
|
|
6122
|
+
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
|
|
6123
|
+
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
6124
|
+
|
|
6125
|
+
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
6126
|
+
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
|
6127
|
+
|
|
6128
|
+
const uint32_t masksh = Bc * Br * sizeof(float);
|
|
6129
|
+
|
|
6130
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
|
6131
|
+
|
|
6132
|
+
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
|
6133
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
6134
|
+
|
|
6135
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
|
6136
|
+
|
|
6137
|
+
return supported;
|
|
6138
|
+
}
|
|
6139
|
+
|
|
6140
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
6141
|
+
// Needs to be kept up to date on shader changes
|
|
6142
|
+
GGML_UNUSED(hsv);
|
|
6143
|
+
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
6144
|
+
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
|
5904
6145
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
5905
6146
|
|
|
5906
6147
|
const uint32_t acctype = f32acc ? 4 : 2;
|
|
@@ -5909,12 +6150,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5909
6150
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
5910
6151
|
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
5911
6152
|
|
|
5912
|
-
const uint32_t Qf = Br * (
|
|
6153
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
|
|
5913
6154
|
|
|
5914
|
-
const uint32_t sfshstride = (
|
|
6155
|
+
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
|
5915
6156
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
5916
6157
|
|
|
5917
|
-
const uint32_t kshstride =
|
|
6158
|
+
const uint32_t kshstride = hsk / 4 + 2;
|
|
5918
6159
|
const uint32_t ksh = Bc * kshstride * f16vec4;
|
|
5919
6160
|
|
|
5920
6161
|
const uint32_t slope = Br * sizeof(float);
|
|
@@ -5922,7 +6163,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5922
6163
|
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
|
5923
6164
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
5924
6165
|
|
|
5925
|
-
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(
|
|
6166
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
5926
6167
|
|
|
5927
6168
|
return supported;
|
|
5928
6169
|
}
|
|
@@ -5944,13 +6185,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5944
6185
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
5945
6186
|
|
|
5946
6187
|
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
|
5947
|
-
const uint32_t
|
|
6188
|
+
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
|
6189
|
+
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
|
5948
6190
|
|
|
5949
|
-
const uint32_t
|
|
6191
|
+
const uint32_t HSK = nek0;
|
|
6192
|
+
const uint32_t HSV = nev0;
|
|
5950
6193
|
uint32_t N = neq1;
|
|
5951
6194
|
const uint32_t KV = nek1;
|
|
5952
6195
|
|
|
5953
|
-
GGML_ASSERT(ne0 ==
|
|
6196
|
+
GGML_ASSERT(ne0 == HSV);
|
|
5954
6197
|
GGML_ASSERT(ne2 == N);
|
|
5955
6198
|
|
|
5956
6199
|
// input tensor rows must be contiguous
|
|
@@ -5958,12 +6201,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5958
6201
|
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
5959
6202
|
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
5960
6203
|
|
|
5961
|
-
GGML_ASSERT(neq0 ==
|
|
5962
|
-
GGML_ASSERT(nek0 == D);
|
|
5963
|
-
GGML_ASSERT(nev0 == D);
|
|
6204
|
+
GGML_ASSERT(neq0 == HSK);
|
|
5964
6205
|
|
|
5965
6206
|
GGML_ASSERT(neq1 == N);
|
|
5966
|
-
GGML_ASSERT(nev0 == D);
|
|
5967
6207
|
|
|
5968
6208
|
GGML_ASSERT(nev1 == nek1);
|
|
5969
6209
|
|
|
@@ -5984,7 +6224,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5984
6224
|
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
5985
6225
|
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
5986
6226
|
|
|
5987
|
-
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device,
|
|
6227
|
+
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
|
5988
6228
|
|
|
5989
6229
|
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
5990
6230
|
path = FA_SCALAR;
|
|
@@ -6004,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6004
6244
|
case FA_SCALAR:
|
|
6005
6245
|
case FA_COOPMAT1:
|
|
6006
6246
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
6007
|
-
max_gqa =
|
|
6247
|
+
max_gqa = get_fa_scalar_num_large_rows(HSV);
|
|
6008
6248
|
break;
|
|
6009
6249
|
case FA_COOPMAT2:
|
|
6010
6250
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
@@ -6014,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6014
6254
|
}
|
|
6015
6255
|
|
|
6016
6256
|
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
6017
|
-
qk_ratio * nek2 == neq2 && nek2 == nev2 &&
|
|
6257
|
+
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
|
6018
6258
|
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
6019
6259
|
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
6020
6260
|
// and change addressing calculations to index Q's dimension 2.
|
|
@@ -6037,47 +6277,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6037
6277
|
path = FA_SCALAR;
|
|
6038
6278
|
}
|
|
6039
6279
|
|
|
6280
|
+
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
|
6281
|
+
if (path == FA_SCALAR &&
|
|
6282
|
+
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
|
|
6283
|
+
small_rows = true;
|
|
6284
|
+
}
|
|
6285
|
+
|
|
6040
6286
|
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
|
6041
6287
|
|
|
6288
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
|
|
6289
|
+
|
|
6042
6290
|
switch (path) {
|
|
6043
6291
|
case FA_SCALAR:
|
|
6044
|
-
|
|
6045
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
6046
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
6047
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
6048
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
6049
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
6050
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
6051
|
-
default:
|
|
6052
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6053
|
-
return;
|
|
6054
|
-
}
|
|
6292
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
|
|
6055
6293
|
break;
|
|
6056
6294
|
case FA_COOPMAT1:
|
|
6057
|
-
|
|
6058
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6059
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6060
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6061
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6062
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6063
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6064
|
-
default:
|
|
6065
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6066
|
-
return;
|
|
6067
|
-
}
|
|
6295
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
|
|
6068
6296
|
break;
|
|
6069
6297
|
case FA_COOPMAT2:
|
|
6070
|
-
|
|
6071
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6072
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6073
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6074
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6075
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6076
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6077
|
-
default:
|
|
6078
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6079
|
-
return;
|
|
6080
|
-
}
|
|
6298
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
|
|
6081
6299
|
break;
|
|
6082
6300
|
default:
|
|
6083
6301
|
GGML_ASSERT(0);
|
|
@@ -6105,21 +6323,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6105
6323
|
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
6106
6324
|
|
|
6107
6325
|
// Try to use split_k when KV is large enough to be worth the overhead
|
|
6108
|
-
if (workgroups_x == 1 && shader_core_count > 0
|
|
6326
|
+
if (workgroups_x == 1 && shader_core_count > 0) {
|
|
6109
6327
|
// Try to run two workgroups per SM.
|
|
6110
|
-
split_k =
|
|
6328
|
+
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
|
6111
6329
|
if (split_k > 1) {
|
|
6112
6330
|
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
|
6113
6331
|
// of "align", so recompute split_k based on that.
|
|
6114
|
-
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
|
6332
|
+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
|
6115
6333
|
split_k = CEIL_DIV(KV, split_kv);
|
|
6116
6334
|
workgroups_x = split_k;
|
|
6117
6335
|
}
|
|
6118
6336
|
}
|
|
6119
6337
|
|
|
6120
|
-
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
|
|
6121
|
-
// and the per-row m and L values (ne1 rows).
|
|
6122
|
-
const uint64_t split_k_size = split_k > 1 ? (
|
|
6338
|
+
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
|
6339
|
+
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
|
6340
|
+
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
|
6123
6341
|
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
|
6124
6342
|
GGML_ABORT("Requested preallocation size is too large");
|
|
6125
6343
|
}
|
|
@@ -6206,18 +6424,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6206
6424
|
}
|
|
6207
6425
|
}
|
|
6208
6426
|
|
|
6427
|
+
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
|
|
6428
|
+
|
|
6209
6429
|
const vk_flash_attn_push_constants pc = { N, KV,
|
|
6210
6430
|
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
|
6211
6431
|
(uint32_t)neq2, (uint32_t)neq3,
|
|
6212
6432
|
(uint32_t)nek2, (uint32_t)nek3,
|
|
6213
6433
|
(uint32_t)nev2, (uint32_t)nev3,
|
|
6214
|
-
nem1,
|
|
6434
|
+
nem1, nem2, nem3,
|
|
6215
6435
|
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
|
6216
6436
|
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
|
6217
6437
|
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
|
6218
|
-
nbm1,
|
|
6219
6438
|
scale, max_bias, logit_softcap,
|
|
6220
|
-
|
|
6439
|
+
mask_n_head_log2, m0, m1,
|
|
6221
6440
|
gqa_ratio, split_kv, split_k };
|
|
6222
6441
|
|
|
6223
6442
|
ggml_vk_sync_buffers(subctx);
|
|
@@ -6238,13 +6457,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6238
6457
|
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
|
6239
6458
|
|
|
6240
6459
|
ggml_vk_sync_buffers(subctx);
|
|
6241
|
-
const std::array<uint32_t,
|
|
6460
|
+
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
|
|
6242
6461
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
|
6243
6462
|
{
|
|
6244
6463
|
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
|
6245
6464
|
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
|
6246
6465
|
},
|
|
6247
|
-
pc2, { (uint32_t)ne1,
|
|
6466
|
+
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
|
6248
6467
|
} else {
|
|
6249
6468
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
6250
6469
|
{
|
|
@@ -6320,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6320
6539
|
}
|
|
6321
6540
|
return nullptr;
|
|
6322
6541
|
case GGML_OP_UPSCALE:
|
|
6323
|
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
|
6324
|
-
|
|
6542
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6543
|
+
int mode = ggml_get_op_params_i32(dst, 0);
|
|
6544
|
+
switch (mode) {
|
|
6545
|
+
case GGML_SCALE_MODE_NEAREST:
|
|
6546
|
+
return ctx->device->pipeline_upscale_nearest_f32;
|
|
6547
|
+
case GGML_SCALE_MODE_BILINEAR:
|
|
6548
|
+
return ctx->device->pipeline_upscale_bilinear_f32;
|
|
6549
|
+
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
|
|
6550
|
+
return ctx->device->pipeline_upscale_bilinear_ac_f32;
|
|
6551
|
+
}
|
|
6325
6552
|
}
|
|
6326
6553
|
return nullptr;
|
|
6327
6554
|
case GGML_OP_SCALE:
|
|
@@ -6354,6 +6581,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6354
6581
|
return ctx->device->pipeline_pad_f32;
|
|
6355
6582
|
}
|
|
6356
6583
|
return nullptr;
|
|
6584
|
+
case GGML_OP_ROLL:
|
|
6585
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6586
|
+
return ctx->device->pipeline_roll_f32;
|
|
6587
|
+
}
|
|
6588
|
+
return nullptr;
|
|
6357
6589
|
case GGML_OP_REPEAT:
|
|
6358
6590
|
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
|
|
6359
6591
|
return ctx->device->pipeline_repeat_f32;
|
|
@@ -6368,6 +6600,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6368
6600
|
case GGML_OP_CONT:
|
|
6369
6601
|
case GGML_OP_DUP:
|
|
6370
6602
|
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
|
6603
|
+
case GGML_OP_SET_ROWS:
|
|
6604
|
+
return ctx->device->pipeline_set_rows[dst->type];
|
|
6371
6605
|
case GGML_OP_SILU_BACK:
|
|
6372
6606
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6373
6607
|
return ctx->device->pipeline_silu_back_f32;
|
|
@@ -6385,7 +6619,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6385
6619
|
return nullptr;
|
|
6386
6620
|
case GGML_OP_RMS_NORM:
|
|
6387
6621
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6388
|
-
return ctx->device->pipeline_rms_norm_f32;
|
|
6622
|
+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
|
|
6389
6623
|
}
|
|
6390
6624
|
return nullptr;
|
|
6391
6625
|
case GGML_OP_RMS_NORM_BACK:
|
|
@@ -6410,6 +6644,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6410
6644
|
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
6411
6645
|
case GGML_UNARY_OP_GELU:
|
|
6412
6646
|
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
|
|
6647
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
6648
|
+
return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
|
|
6413
6649
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
6414
6650
|
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
|
|
6415
6651
|
case GGML_UNARY_OP_RELU:
|
|
@@ -6422,6 +6658,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6422
6658
|
break;
|
|
6423
6659
|
}
|
|
6424
6660
|
return nullptr;
|
|
6661
|
+
case GGML_OP_GLU:
|
|
6662
|
+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
|
6663
|
+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
|
|
6664
|
+
(src0->type != dst->type)) {
|
|
6665
|
+
return nullptr;
|
|
6666
|
+
}
|
|
6667
|
+
|
|
6668
|
+
switch (ggml_get_glu_op(dst)) {
|
|
6669
|
+
case GGML_GLU_OP_GEGLU:
|
|
6670
|
+
return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
|
|
6671
|
+
case GGML_GLU_OP_REGLU:
|
|
6672
|
+
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
|
6673
|
+
case GGML_GLU_OP_SWIGLU:
|
|
6674
|
+
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
|
6675
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
6676
|
+
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
|
|
6677
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
6678
|
+
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
|
|
6679
|
+
default:
|
|
6680
|
+
break;
|
|
6681
|
+
}
|
|
6682
|
+
return nullptr;
|
|
6425
6683
|
case GGML_OP_DIAG_MASK_INF:
|
|
6426
6684
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6427
6685
|
return ctx->device->pipeline_diag_mask_inf_f32;
|
|
@@ -6582,6 +6840,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
6582
6840
|
case GGML_OP_RMS_NORM:
|
|
6583
6841
|
case GGML_OP_CONV_2D_DW:
|
|
6584
6842
|
case GGML_OP_IM2COL:
|
|
6843
|
+
case GGML_OP_SET_ROWS:
|
|
6585
6844
|
return true;
|
|
6586
6845
|
default:
|
|
6587
6846
|
return false;
|
|
@@ -6876,12 +7135,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6876
7135
|
case GGML_OP_COS:
|
|
6877
7136
|
case GGML_OP_CLAMP:
|
|
6878
7137
|
case GGML_OP_PAD:
|
|
7138
|
+
case GGML_OP_ROLL:
|
|
6879
7139
|
case GGML_OP_REPEAT:
|
|
6880
7140
|
case GGML_OP_REPEAT_BACK:
|
|
6881
7141
|
case GGML_OP_CPY:
|
|
6882
7142
|
case GGML_OP_CONCAT:
|
|
6883
7143
|
case GGML_OP_UPSCALE:
|
|
6884
7144
|
case GGML_OP_UNARY:
|
|
7145
|
+
case GGML_OP_GLU:
|
|
6885
7146
|
case GGML_OP_CONV_2D_DW:
|
|
6886
7147
|
{
|
|
6887
7148
|
uint32_t ne = ggml_nelements(dst);
|
|
@@ -6894,6 +7155,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6894
7155
|
ne *= ggml_type_size(src0->type) / 2;
|
|
6895
7156
|
}
|
|
6896
7157
|
}
|
|
7158
|
+
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
|
|
7159
|
+
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
|
|
7160
|
+
// So divide by block size here before splitting into 512x512 groups.
|
|
7161
|
+
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7162
|
+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
|
|
7163
|
+
}
|
|
6897
7164
|
if (ne > 262144) {
|
|
6898
7165
|
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
6899
7166
|
} else if (ne > 512) {
|
|
@@ -6902,6 +7169,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6902
7169
|
elements = { ne, 1, 1 };
|
|
6903
7170
|
}
|
|
6904
7171
|
} break;
|
|
7172
|
+
case GGML_OP_SET_ROWS:
|
|
7173
|
+
{
|
|
7174
|
+
uint32_t ne = ggml_nelements(src0);
|
|
7175
|
+
if (ggml_is_quantized(dst->type)) {
|
|
7176
|
+
// quants run 32 threads each doing QUANT_K elements
|
|
7177
|
+
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
|
|
7178
|
+
} else {
|
|
7179
|
+
// scalar types do one element per thread, running 512 threads
|
|
7180
|
+
ne = CEIL_DIV(ne, 512);
|
|
7181
|
+
}
|
|
7182
|
+
if (ne > 262144) {
|
|
7183
|
+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
7184
|
+
} else if (ne > 512) {
|
|
7185
|
+
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
|
7186
|
+
} else {
|
|
7187
|
+
elements = { ne, 1, 1 };
|
|
7188
|
+
}
|
|
7189
|
+
}
|
|
7190
|
+
break;
|
|
6905
7191
|
default:
|
|
6906
7192
|
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
|
|
6907
7193
|
break;
|
|
@@ -6922,7 +7208,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6922
7208
|
}
|
|
6923
7209
|
}
|
|
6924
7210
|
|
|
6925
|
-
if (op == GGML_OP_SOFT_MAX) {
|
|
7211
|
+
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
|
|
6926
7212
|
// Empty src1 is possible in soft_max, but the shader needs a buffer
|
|
6927
7213
|
vk_subbuffer subbuf_y;
|
|
6928
7214
|
if (use_src1) {
|
|
@@ -7311,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7311
7597
|
|
|
7312
7598
|
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7313
7599
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7600
|
+
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
|
7314
7601
|
|
|
7315
|
-
|
|
7316
|
-
|
|
7317
|
-
|
|
7318
|
-
|
|
7602
|
+
float sf0 = (float)dst->ne[0] / src0->ne[0];
|
|
7603
|
+
float sf1 = (float)dst->ne[1] / src0->ne[1];
|
|
7604
|
+
float sf2 = (float)dst->ne[2] / src0->ne[2];
|
|
7605
|
+
float sf3 = (float)dst->ne[3] / src0->ne[3];
|
|
7606
|
+
|
|
7607
|
+
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7608
|
+
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
|
7609
|
+
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
|
7610
|
+
}
|
|
7319
7611
|
|
|
7320
7612
|
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
|
7321
7613
|
(uint32_t)ggml_nelements(dst), 0, 0,
|
|
7614
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
|
|
7322
7615
|
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7323
7616
|
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
|
7324
7617
|
sf0, sf1, sf2, sf3,
|
|
@@ -7326,123 +7619,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
7326
7619
|
}
|
|
7327
7620
|
|
|
7328
7621
|
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7329
|
-
|
|
7330
|
-
|
|
7331
|
-
|
|
7622
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7623
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7624
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7332
7625
|
|
|
7333
|
-
ggml_vk_op_f32
|
|
7334
|
-
(uint32_t)ggml_nelements(src0),
|
|
7335
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7336
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7337
|
-
0,
|
|
7338
|
-
op_params[0], 0.0f,
|
|
7339
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7340
|
-
}, dryrun);
|
|
7626
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
|
|
7341
7627
|
}
|
|
7342
7628
|
|
|
7343
7629
|
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7344
|
-
|
|
7345
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7346
|
-
|
|
7347
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
|
|
7348
|
-
(uint32_t)ggml_nelements(src0),
|
|
7349
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7350
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7351
|
-
0,
|
|
7352
|
-
0.0f, 0.0f,
|
|
7353
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7354
|
-
}, dryrun);
|
|
7630
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7355
7631
|
}
|
|
7356
7632
|
|
|
7357
7633
|
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7358
|
-
|
|
7359
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7360
|
-
|
|
7361
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
|
7362
|
-
(uint32_t)ggml_nelements(src0),
|
|
7363
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7364
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7365
|
-
0,
|
|
7366
|
-
0.0f, 0.0f,
|
|
7367
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7368
|
-
}, dryrun);
|
|
7634
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7369
7635
|
}
|
|
7370
7636
|
|
|
7371
7637
|
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7372
|
-
|
|
7373
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7374
|
-
|
|
7375
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
|
7376
|
-
(uint32_t)ggml_nelements(src0),
|
|
7377
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7378
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7379
|
-
0,
|
|
7380
|
-
0.0f, 0.0f,
|
|
7381
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7382
|
-
}, dryrun);
|
|
7638
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7383
7639
|
}
|
|
7384
7640
|
|
|
7385
7641
|
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7386
|
-
|
|
7387
|
-
|
|
7388
|
-
|
|
7642
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7643
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7644
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7389
7645
|
|
|
7390
|
-
ggml_vk_op_f32
|
|
7391
|
-
(uint32_t)ggml_nelements(src0),
|
|
7392
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7393
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7394
|
-
0,
|
|
7395
|
-
op_params[0], op_params[1],
|
|
7396
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7397
|
-
}, dryrun);
|
|
7646
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
|
|
7398
7647
|
}
|
|
7399
7648
|
|
|
7400
7649
|
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7401
|
-
|
|
7402
|
-
|
|
7650
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7651
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
|
|
7652
|
+
}
|
|
7403
7653
|
|
|
7404
|
-
|
|
7405
|
-
|
|
7406
|
-
|
|
7407
|
-
|
|
7408
|
-
|
|
7409
|
-
|
|
7410
|
-
|
|
7411
|
-
|
|
7654
|
+
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7655
|
+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
|
7656
|
+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
|
7657
|
+
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
|
|
7658
|
+
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
|
|
7659
|
+
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
|
|
7660
|
+
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
|
|
7661
|
+
|
|
7662
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7663
|
+
memcpy(&p.param1, &s01_packed, sizeof(float));
|
|
7664
|
+
memcpy(&p.param2, &s23_packed, sizeof(float));
|
|
7665
|
+
|
|
7666
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
|
|
7412
7667
|
}
|
|
7413
7668
|
|
|
7414
7669
|
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7415
|
-
|
|
7416
|
-
|
|
7417
|
-
|
|
7418
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
|
|
7419
|
-
(uint32_t)ggml_nelements(dst),
|
|
7420
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7421
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7422
|
-
0,
|
|
7423
|
-
0.0f, 0.0f,
|
|
7424
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7425
|
-
}, dryrun);
|
|
7670
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7671
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
|
|
7426
7672
|
}
|
|
7427
7673
|
|
|
7428
7674
|
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7429
|
-
|
|
7430
|
-
|
|
7431
|
-
|
|
7432
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
|
|
7433
|
-
(uint32_t)ggml_nelements(dst),
|
|
7434
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7435
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7436
|
-
0,
|
|
7437
|
-
0.0f, 0.0f,
|
|
7438
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7439
|
-
}, dryrun);
|
|
7675
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7676
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
|
|
7440
7677
|
}
|
|
7441
7678
|
|
|
7442
7679
|
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7443
|
-
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7444
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7445
|
-
|
|
7446
7680
|
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
|
7447
7681
|
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7448
7682
|
// Convert from number of logical elements to 2- or 4-byte units.
|
|
@@ -7454,13 +7688,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7454
7688
|
}
|
|
7455
7689
|
}
|
|
7456
7690
|
|
|
7457
|
-
|
|
7458
|
-
|
|
7459
|
-
|
|
7460
|
-
|
|
7691
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
|
|
7692
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
|
|
7693
|
+
}
|
|
7694
|
+
|
|
7695
|
+
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7696
|
+
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7697
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7698
|
+
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7699
|
+
|
|
7700
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
|
|
7701
|
+
(uint32_t)ggml_nelements(src0),
|
|
7702
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7703
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7704
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7461
7705
|
0,
|
|
7462
|
-
0.0f, 0.0f,
|
|
7463
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7706
|
+
0.0f, 0.0f, 0,
|
|
7464
7707
|
}, dryrun);
|
|
7465
7708
|
}
|
|
7466
7709
|
|
|
@@ -7485,18 +7728,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
7485
7728
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
|
7486
7729
|
}
|
|
7487
7730
|
|
|
7488
|
-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7489
|
-
float * op_params = (float *)dst->op_params;
|
|
7731
|
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
|
7490
7732
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7733
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7491
7734
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7492
7735
|
|
|
7493
|
-
ggml_vk_op_f32<
|
|
7736
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
|
|
7494
7737
|
(uint32_t)ggml_nelements(src0),
|
|
7495
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7496
|
-
(uint32_t)
|
|
7738
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7739
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7740
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7497
7741
|
0,
|
|
7498
|
-
op_params[0], 0.0f,
|
|
7499
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7742
|
+
op_params[0], 0.0f, 0,
|
|
7500
7743
|
}, dryrun);
|
|
7501
7744
|
}
|
|
7502
7745
|
|
|
@@ -7514,6 +7757,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
7514
7757
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
|
7515
7758
|
}
|
|
7516
7759
|
|
|
7760
|
+
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7761
|
+
const bool swapped = (bool)dst->op_params[1];
|
|
7762
|
+
const bool split = src1 != nullptr;
|
|
7763
|
+
|
|
7764
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
7765
|
+
|
|
7766
|
+
if (!split) {
|
|
7767
|
+
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
|
7768
|
+
} else {
|
|
7769
|
+
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
|
|
7770
|
+
GGML_ASSERT(src0->ne[0] == dst->ne[0]);
|
|
7771
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
7772
|
+
}
|
|
7773
|
+
|
|
7774
|
+
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
|
|
7775
|
+
|
|
7776
|
+
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
|
|
7777
|
+
}
|
|
7778
|
+
|
|
7517
7779
|
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7518
7780
|
int32_t * op_params = (int32_t *)dst->op_params;
|
|
7519
7781
|
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
|
@@ -7529,7 +7791,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7529
7791
|
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
|
|
7530
7792
|
const uint32_t nrows_y = (uint32_t)src0->ne[1];
|
|
7531
7793
|
|
|
7532
|
-
const uint32_t
|
|
7794
|
+
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
|
|
7795
|
+
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
|
|
7796
|
+
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
|
|
7797
|
+
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
|
|
7798
|
+
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
|
|
7799
|
+
|
|
7800
|
+
const uint32_t n_head_kv = src0->ne[2];
|
|
7533
7801
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
7534
7802
|
|
|
7535
7803
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -7538,6 +7806,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7538
7806
|
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
|
7539
7807
|
ncols,
|
|
7540
7808
|
src1 != nullptr ? nrows_y : (uint32_t)0,
|
|
7809
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7810
|
+
ne12, ne13,
|
|
7811
|
+
nb11, nb12, nb13,
|
|
7541
7812
|
scale, max_bias,
|
|
7542
7813
|
m0, m1,
|
|
7543
7814
|
n_head_log2,
|
|
@@ -8687,11 +8958,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
8687
8958
|
}
|
|
8688
8959
|
}
|
|
8689
8960
|
|
|
8690
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
8961
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
8691
8962
|
|
|
8692
8963
|
// Returns true if node has enqueued work into the queue, false otherwise
|
|
8693
8964
|
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
|
8694
|
-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx,
|
|
8965
|
+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
|
|
8966
|
+
ggml_tensor * node = cgraph->nodes[node_idx];
|
|
8695
8967
|
if (ggml_is_empty(node) || !node->buffer) {
|
|
8696
8968
|
return false;
|
|
8697
8969
|
}
|
|
@@ -8716,6 +8988,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8716
8988
|
switch (ggml_get_unary_op(node)) {
|
|
8717
8989
|
case GGML_UNARY_OP_SILU:
|
|
8718
8990
|
case GGML_UNARY_OP_GELU:
|
|
8991
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8719
8992
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8720
8993
|
case GGML_UNARY_OP_RELU:
|
|
8721
8994
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8725,6 +8998,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8725
8998
|
return false;
|
|
8726
8999
|
}
|
|
8727
9000
|
break;
|
|
9001
|
+
case GGML_OP_GLU:
|
|
9002
|
+
switch (ggml_get_glu_op(node)) {
|
|
9003
|
+
case GGML_GLU_OP_GEGLU:
|
|
9004
|
+
case GGML_GLU_OP_REGLU:
|
|
9005
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9006
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9007
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9008
|
+
break;
|
|
9009
|
+
default:
|
|
9010
|
+
return false;
|
|
9011
|
+
}
|
|
9012
|
+
break;
|
|
8728
9013
|
case GGML_OP_REPEAT:
|
|
8729
9014
|
case GGML_OP_REPEAT_BACK:
|
|
8730
9015
|
case GGML_OP_GET_ROWS:
|
|
@@ -8741,7 +9026,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8741
9026
|
case GGML_OP_COS:
|
|
8742
9027
|
case GGML_OP_CLAMP:
|
|
8743
9028
|
case GGML_OP_PAD:
|
|
9029
|
+
case GGML_OP_ROLL:
|
|
8744
9030
|
case GGML_OP_CPY:
|
|
9031
|
+
case GGML_OP_SET_ROWS:
|
|
8745
9032
|
case GGML_OP_CONT:
|
|
8746
9033
|
case GGML_OP_DUP:
|
|
8747
9034
|
case GGML_OP_SILU_BACK:
|
|
@@ -8808,6 +9095,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8808
9095
|
case GGML_OP_CLAMP:
|
|
8809
9096
|
case GGML_OP_PAD:
|
|
8810
9097
|
case GGML_OP_CPY:
|
|
9098
|
+
case GGML_OP_SET_ROWS:
|
|
8811
9099
|
case GGML_OP_CONT:
|
|
8812
9100
|
case GGML_OP_DUP:
|
|
8813
9101
|
case GGML_OP_SILU_BACK:
|
|
@@ -8817,6 +9105,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8817
9105
|
case GGML_OP_RMS_NORM_BACK:
|
|
8818
9106
|
case GGML_OP_L2_NORM:
|
|
8819
9107
|
case GGML_OP_UNARY:
|
|
9108
|
+
case GGML_OP_GLU:
|
|
8820
9109
|
case GGML_OP_DIAG_MASK_INF:
|
|
8821
9110
|
case GGML_OP_SOFT_MAX:
|
|
8822
9111
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -8909,12 +9198,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8909
9198
|
case GGML_OP_PAD:
|
|
8910
9199
|
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
|
|
8911
9200
|
|
|
9201
|
+
break;
|
|
9202
|
+
case GGML_OP_ROLL:
|
|
9203
|
+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
|
|
9204
|
+
|
|
8912
9205
|
break;
|
|
8913
9206
|
case GGML_OP_CPY:
|
|
8914
9207
|
case GGML_OP_CONT:
|
|
8915
9208
|
case GGML_OP_DUP:
|
|
8916
9209
|
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
|
8917
9210
|
|
|
9211
|
+
break;
|
|
9212
|
+
case GGML_OP_SET_ROWS:
|
|
9213
|
+
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9214
|
+
|
|
8918
9215
|
break;
|
|
8919
9216
|
case GGML_OP_SILU_BACK:
|
|
8920
9217
|
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8929,8 +9226,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8929
9226
|
|
|
8930
9227
|
break;
|
|
8931
9228
|
case GGML_OP_RMS_NORM:
|
|
8932
|
-
|
|
8933
|
-
|
|
9229
|
+
if (ctx->num_additional_fused_ops > 0) {
|
|
9230
|
+
// fused rms_norm + mul
|
|
9231
|
+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
9232
|
+
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
|
9233
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
|
|
9234
|
+
} else {
|
|
9235
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
|
|
9236
|
+
}
|
|
8934
9237
|
break;
|
|
8935
9238
|
case GGML_OP_RMS_NORM_BACK:
|
|
8936
9239
|
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8944,6 +9247,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8944
9247
|
switch (ggml_get_unary_op(node)) {
|
|
8945
9248
|
case GGML_UNARY_OP_SILU:
|
|
8946
9249
|
case GGML_UNARY_OP_GELU:
|
|
9250
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8947
9251
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8948
9252
|
case GGML_UNARY_OP_RELU:
|
|
8949
9253
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8954,6 +9258,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8954
9258
|
return false;
|
|
8955
9259
|
}
|
|
8956
9260
|
break;
|
|
9261
|
+
case GGML_OP_GLU:
|
|
9262
|
+
switch (ggml_get_glu_op(node)) {
|
|
9263
|
+
case GGML_GLU_OP_GEGLU:
|
|
9264
|
+
case GGML_GLU_OP_REGLU:
|
|
9265
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9266
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9267
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9268
|
+
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9269
|
+
break;
|
|
9270
|
+
default:
|
|
9271
|
+
return false;
|
|
9272
|
+
}
|
|
9273
|
+
break;
|
|
8957
9274
|
case GGML_OP_DIAG_MASK_INF:
|
|
8958
9275
|
ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
|
|
8959
9276
|
|
|
@@ -9075,12 +9392,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9075
9392
|
|
|
9076
9393
|
ctx->compute_ctx.reset();
|
|
9077
9394
|
|
|
9078
|
-
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
|
|
9395
|
+
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
|
|
9079
9396
|
if (!ok) {
|
|
9080
9397
|
if (node->op == GGML_OP_UNARY) {
|
|
9081
9398
|
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
9082
|
-
}
|
|
9083
|
-
|
|
9399
|
+
} else if (node->op == GGML_OP_GLU) {
|
|
9400
|
+
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
|
9401
|
+
} else {
|
|
9084
9402
|
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
|
9085
9403
|
}
|
|
9086
9404
|
}
|
|
@@ -9089,7 +9407,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9089
9407
|
return true;
|
|
9090
9408
|
}
|
|
9091
9409
|
|
|
9092
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9410
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9411
|
+
GGML_UNUSED(cgraph);
|
|
9093
9412
|
ggml_backend_buffer * buf = nullptr;
|
|
9094
9413
|
|
|
9095
9414
|
switch (tensor->op) {
|
|
@@ -9107,7 +9426,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9107
9426
|
case GGML_OP_COS:
|
|
9108
9427
|
case GGML_OP_CLAMP:
|
|
9109
9428
|
case GGML_OP_PAD:
|
|
9429
|
+
case GGML_OP_ROLL:
|
|
9110
9430
|
case GGML_OP_CPY:
|
|
9431
|
+
case GGML_OP_SET_ROWS:
|
|
9111
9432
|
case GGML_OP_CONT:
|
|
9112
9433
|
case GGML_OP_DUP:
|
|
9113
9434
|
case GGML_OP_SILU_BACK:
|
|
@@ -9149,6 +9470,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9149
9470
|
switch (ggml_get_unary_op(tensor)) {
|
|
9150
9471
|
case GGML_UNARY_OP_SILU:
|
|
9151
9472
|
case GGML_UNARY_OP_GELU:
|
|
9473
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9152
9474
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9153
9475
|
case GGML_UNARY_OP_RELU:
|
|
9154
9476
|
case GGML_UNARY_OP_TANH:
|
|
@@ -9159,6 +9481,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9159
9481
|
return false;
|
|
9160
9482
|
}
|
|
9161
9483
|
break;
|
|
9484
|
+
case GGML_OP_GLU:
|
|
9485
|
+
switch (ggml_get_glu_op(tensor)) {
|
|
9486
|
+
case GGML_GLU_OP_GEGLU:
|
|
9487
|
+
case GGML_GLU_OP_REGLU:
|
|
9488
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9489
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9490
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9491
|
+
buf = tensor->buffer;
|
|
9492
|
+
break;
|
|
9493
|
+
default:
|
|
9494
|
+
return false;
|
|
9495
|
+
}
|
|
9496
|
+
break;
|
|
9162
9497
|
case GGML_OP_MUL_MAT:
|
|
9163
9498
|
case GGML_OP_MUL_MAT_ID:
|
|
9164
9499
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
@@ -9185,7 +9520,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9185
9520
|
// Only run if ctx hasn't been submitted yet
|
|
9186
9521
|
if (!subctx->seqs.empty()) {
|
|
9187
9522
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9188
|
-
ggml_vk_check_results_0(
|
|
9523
|
+
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
|
|
9189
9524
|
use_fence = true;
|
|
9190
9525
|
#endif
|
|
9191
9526
|
|
|
@@ -9205,7 +9540,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9205
9540
|
ggml_vk_wait_for_fence(ctx);
|
|
9206
9541
|
}
|
|
9207
9542
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9208
|
-
ggml_vk_check_results_1(
|
|
9543
|
+
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
|
|
9209
9544
|
#endif
|
|
9210
9545
|
}
|
|
9211
9546
|
|
|
@@ -9652,16 +9987,59 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
|
|
|
9652
9987
|
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
|
9653
9988
|
}
|
|
9654
9989
|
|
|
9990
|
+
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
|
9991
|
+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
|
9992
|
+
return false;
|
|
9993
|
+
}
|
|
9994
|
+
|
|
9995
|
+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
|
9996
|
+
// additional constraints specific to this fusion
|
|
9997
|
+
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
|
9998
|
+
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
9999
|
+
|
|
10000
|
+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
|
10001
|
+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
|
10002
|
+
// rms_norm only supports f32
|
|
10003
|
+
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
|
10004
|
+
mul->src[1]->type != GGML_TYPE_F32 ||
|
|
10005
|
+
mul->type != GGML_TYPE_F32) {
|
|
10006
|
+
return false;
|
|
10007
|
+
}
|
|
10008
|
+
// if rms_norm is the B operand, then we don't handle broadcast
|
|
10009
|
+
if (rms_norm == mul->src[1] &&
|
|
10010
|
+
mul->src[0]->ne[1] != rms_norm->ne[1]) {
|
|
10011
|
+
return false;
|
|
10012
|
+
}
|
|
10013
|
+
// rms_norm shader assumes contiguous rows
|
|
10014
|
+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
|
10015
|
+
return false;
|
|
10016
|
+
}
|
|
10017
|
+
}
|
|
10018
|
+
return true;
|
|
10019
|
+
}
|
|
10020
|
+
|
|
9655
10021
|
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
9656
10022
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
|
9657
10023
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
9658
10024
|
|
|
10025
|
+
if (vk_instance.debug_utils_support) {
|
|
10026
|
+
vk::DebugUtilsLabelEXT dul = {};
|
|
10027
|
+
dul.pLabelName = "ggml_backend_vk_graph_compute";
|
|
10028
|
+
dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
|
|
10029
|
+
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
|
|
10030
|
+
}
|
|
10031
|
+
|
|
9659
10032
|
uint64_t total_mat_mul_bytes = 0;
|
|
9660
10033
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
9661
|
-
|
|
10034
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10035
|
+
ctx->num_additional_fused_ops = 1;
|
|
10036
|
+
}
|
|
10037
|
+
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
|
9662
10038
|
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
|
9663
10039
|
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
9664
10040
|
}
|
|
10041
|
+
i += ctx->num_additional_fused_ops;
|
|
10042
|
+
ctx->num_additional_fused_ops = 0;
|
|
9665
10043
|
}
|
|
9666
10044
|
if (ctx->device->need_compiles) {
|
|
9667
10045
|
ggml_vk_load_shaders(ctx->device);
|
|
@@ -9723,14 +10101,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9723
10101
|
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
9724
10102
|
}
|
|
9725
10103
|
|
|
10104
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10105
|
+
ctx->num_additional_fused_ops = 1;
|
|
10106
|
+
}
|
|
10107
|
+
|
|
9726
10108
|
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
|
9727
10109
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
9728
10110
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
9729
10111
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
9730
|
-
(i == last_node) ||
|
|
10112
|
+
(i + ctx->num_additional_fused_ops == last_node) ||
|
|
9731
10113
|
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
9732
10114
|
|
|
9733
|
-
bool enqueued = ggml_vk_build_graph(ctx, cgraph
|
|
10115
|
+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
|
|
9734
10116
|
|
|
9735
10117
|
if (vk_perf_logger_enabled) {
|
|
9736
10118
|
if (ctx->compute_ctx.expired()) {
|
|
@@ -9740,7 +10122,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9740
10122
|
} else {
|
|
9741
10123
|
compute_ctx = ctx->compute_ctx.lock();
|
|
9742
10124
|
}
|
|
9743
|
-
|
|
10125
|
+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
|
|
10126
|
+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
|
|
10127
|
+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
|
|
10128
|
+
}
|
|
9744
10129
|
}
|
|
9745
10130
|
|
|
9746
10131
|
if (enqueued) {
|
|
@@ -9762,6 +10147,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9762
10147
|
}
|
|
9763
10148
|
submit_count++;
|
|
9764
10149
|
}
|
|
10150
|
+
i += ctx->num_additional_fused_ops;
|
|
10151
|
+
ctx->num_additional_fused_ops = 0;
|
|
9765
10152
|
}
|
|
9766
10153
|
|
|
9767
10154
|
if (vk_perf_logger_enabled) {
|
|
@@ -9923,6 +10310,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9923
10310
|
case GGML_OP_UNARY:
|
|
9924
10311
|
switch (ggml_get_unary_op(op)) {
|
|
9925
10312
|
case GGML_UNARY_OP_GELU:
|
|
10313
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9926
10314
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9927
10315
|
case GGML_UNARY_OP_SILU:
|
|
9928
10316
|
case GGML_UNARY_OP_RELU:
|
|
@@ -9936,15 +10324,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9936
10324
|
return false;
|
|
9937
10325
|
}
|
|
9938
10326
|
break;
|
|
10327
|
+
case GGML_OP_GLU:
|
|
10328
|
+
switch (ggml_get_glu_op(op)) {
|
|
10329
|
+
case GGML_GLU_OP_GEGLU:
|
|
10330
|
+
case GGML_GLU_OP_REGLU:
|
|
10331
|
+
case GGML_GLU_OP_SWIGLU:
|
|
10332
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
10333
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
10334
|
+
return ggml_is_contiguous(op->src[0]) &&
|
|
10335
|
+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
10336
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
10337
|
+
(op->src[0]->type == op->type);
|
|
10338
|
+
default:
|
|
10339
|
+
return false;
|
|
10340
|
+
}
|
|
10341
|
+
break;
|
|
9939
10342
|
case GGML_OP_MUL_MAT:
|
|
9940
10343
|
case GGML_OP_MUL_MAT_ID:
|
|
9941
10344
|
{
|
|
9942
10345
|
ggml_type src0_type = op->src[0]->type;
|
|
9943
10346
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
9944
10347
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
9945
|
-
if (op->op == GGML_OP_MUL_MAT_ID
|
|
9946
|
-
|
|
9947
|
-
|
|
10348
|
+
if (op->op == GGML_OP_MUL_MAT_ID) {
|
|
10349
|
+
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
|
10350
|
+
// If there's not enough shared memory for row_ids and the result tile, fall back to CPU
|
|
10351
|
+
return false;
|
|
10352
|
+
}
|
|
10353
|
+
// Check against size of shared memory variable
|
|
10354
|
+
if (op->src[2]->ne[0] > 4096) {
|
|
10355
|
+
return false;
|
|
10356
|
+
}
|
|
9948
10357
|
}
|
|
9949
10358
|
switch (src0_type) {
|
|
9950
10359
|
case GGML_TYPE_F32:
|
|
@@ -10002,19 +10411,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10002
10411
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
10003
10412
|
auto device = ggml_vk_get_device(ctx->device);
|
|
10004
10413
|
bool coopmat2 = device->coopmat2;
|
|
10005
|
-
|
|
10006
|
-
|
|
10007
|
-
case 80:
|
|
10008
|
-
case 96:
|
|
10009
|
-
case 112:
|
|
10010
|
-
case 128:
|
|
10011
|
-
case 256:
|
|
10012
|
-
break;
|
|
10013
|
-
default:
|
|
10014
|
-
return false;
|
|
10015
|
-
}
|
|
10016
|
-
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
|
10017
|
-
// different head sizes of K and V are not supported yet
|
|
10414
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
|
|
10415
|
+
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
|
|
10018
10416
|
return false;
|
|
10019
10417
|
}
|
|
10020
10418
|
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
@@ -10094,6 +10492,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10094
10492
|
return false;
|
|
10095
10493
|
}
|
|
10096
10494
|
} break;
|
|
10495
|
+
case GGML_OP_SET_ROWS:
|
|
10496
|
+
{
|
|
10497
|
+
switch (op->type) {
|
|
10498
|
+
case GGML_TYPE_F32:
|
|
10499
|
+
case GGML_TYPE_F16:
|
|
10500
|
+
case GGML_TYPE_BF16:
|
|
10501
|
+
case GGML_TYPE_Q4_0:
|
|
10502
|
+
case GGML_TYPE_Q4_1:
|
|
10503
|
+
case GGML_TYPE_Q5_0:
|
|
10504
|
+
case GGML_TYPE_Q5_1:
|
|
10505
|
+
case GGML_TYPE_Q8_0:
|
|
10506
|
+
case GGML_TYPE_IQ4_NL:
|
|
10507
|
+
return true;
|
|
10508
|
+
default:
|
|
10509
|
+
return false;
|
|
10510
|
+
}
|
|
10511
|
+
} break;
|
|
10097
10512
|
case GGML_OP_CONT:
|
|
10098
10513
|
case GGML_OP_CPY:
|
|
10099
10514
|
case GGML_OP_DUP:
|
|
@@ -10178,11 +10593,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10178
10593
|
case GGML_OP_CLAMP:
|
|
10179
10594
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
10180
10595
|
case GGML_OP_UPSCALE:
|
|
10181
|
-
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
10182
10596
|
case GGML_OP_ACC:
|
|
10183
10597
|
case GGML_OP_CONCAT:
|
|
10184
10598
|
case GGML_OP_SCALE:
|
|
10185
10599
|
case GGML_OP_PAD:
|
|
10600
|
+
case GGML_OP_ROLL:
|
|
10186
10601
|
case GGML_OP_DIAG_MASK_INF:
|
|
10187
10602
|
case GGML_OP_SOFT_MAX:
|
|
10188
10603
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -10345,6 +10760,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
|
|
10345
10760
|
UNUSED(instance_extensions);
|
|
10346
10761
|
}
|
|
10347
10762
|
|
|
10763
|
+
// Extension availability
|
|
10764
|
+
static bool ggml_vk_instance_debug_utils_ext_available(
|
|
10765
|
+
const std::vector<vk::ExtensionProperties> & instance_extensions) {
|
|
10766
|
+
// Check whether the VK_EXT_debug_utils instance extension is available
|
|
10767
|
+
for (const auto & properties : instance_extensions) {
|
|
10768
|
+
if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
|
|
10769
|
+
return true;
|
|
10770
|
+
}
|
|
10771
|
+
}
|
|
10772
|
+
|
|
10773
|
+
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
|
|
10774
|
+
return false;
|
|
10775
|
+
|
|
10776
|
+
UNUSED(instance_extensions);
|
|
10777
|
+
}
|
|
10778
|
+
|
|
10348
10779
|
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
|
10349
10780
|
switch (props.vendorID) {
|
|
10350
10781
|
case VK_VENDOR_ID_INTEL:
|
|
@@ -10457,11 +10888,21 @@ void * comp_result;
|
|
|
10457
10888
|
size_t comp_size;
|
|
10458
10889
|
size_t comp_nb[GGML_MAX_DIMS];
|
|
10459
10890
|
size_t check_counter = 0;
|
|
10460
|
-
static void ggml_vk_check_results_0(
|
|
10891
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
10892
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10461
10893
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10462
10894
|
return;
|
|
10463
10895
|
}
|
|
10464
10896
|
|
|
10897
|
+
bool fused_rms_norm_mul = false;
|
|
10898
|
+
int rms_norm_idx = -1;
|
|
10899
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
10900
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
10901
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
10902
|
+
fused_rms_norm_mul = true;
|
|
10903
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
10904
|
+
}
|
|
10905
|
+
|
|
10465
10906
|
check_counter++;
|
|
10466
10907
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10467
10908
|
return;
|
|
@@ -10489,6 +10930,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10489
10930
|
|
|
10490
10931
|
for (int i = 0; i < 6; i++) {
|
|
10491
10932
|
ggml_tensor * srci = tensor->src[i];
|
|
10933
|
+
if (fused_rms_norm_mul) {
|
|
10934
|
+
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
|
10935
|
+
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
|
|
10936
|
+
switch (i) {
|
|
10937
|
+
case 0: srci = rms_norm->src[0]; break;
|
|
10938
|
+
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
|
|
10939
|
+
default: continue;
|
|
10940
|
+
}
|
|
10941
|
+
}
|
|
10492
10942
|
if (srci == nullptr) {
|
|
10493
10943
|
continue;
|
|
10494
10944
|
}
|
|
@@ -10546,7 +10996,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10546
10996
|
} else if (tensor->op == GGML_OP_SUB) {
|
|
10547
10997
|
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10548
10998
|
} else if (tensor->op == GGML_OP_MUL) {
|
|
10549
|
-
|
|
10999
|
+
if (fused_rms_norm_mul) {
|
|
11000
|
+
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
|
|
11001
|
+
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
|
|
11002
|
+
} else {
|
|
11003
|
+
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
|
11004
|
+
}
|
|
10550
11005
|
} else if (tensor->op == GGML_OP_DIV) {
|
|
10551
11006
|
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10552
11007
|
} else if (tensor->op == GGML_OP_CONCAT) {
|
|
@@ -10634,6 +11089,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10634
11089
|
case GGML_UNARY_OP_GELU:
|
|
10635
11090
|
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
|
|
10636
11091
|
break;
|
|
11092
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
11093
|
+
tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
|
|
11094
|
+
break;
|
|
10637
11095
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
10638
11096
|
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
|
|
10639
11097
|
break;
|
|
@@ -10650,6 +11108,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10650
11108
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
10651
11109
|
GGML_ABORT("fatal error");
|
|
10652
11110
|
}
|
|
11111
|
+
} else if (tensor->op == GGML_OP_GLU) {
|
|
11112
|
+
if (src_clone[1] == nullptr) {
|
|
11113
|
+
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
|
|
11114
|
+
} else {
|
|
11115
|
+
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
|
11116
|
+
}
|
|
10653
11117
|
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
|
10654
11118
|
if (src1 == nullptr) {
|
|
10655
11119
|
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
|
@@ -10657,6 +11121,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10657
11121
|
} else {
|
|
10658
11122
|
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10659
11123
|
}
|
|
11124
|
+
} else if (tensor->op == GGML_OP_SET_ROWS) {
|
|
11125
|
+
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10660
11126
|
} else if (tensor->op == GGML_OP_CONT) {
|
|
10661
11127
|
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
10662
11128
|
} else if (tensor->op == GGML_OP_RESHAPE) {
|
|
@@ -10728,10 +11194,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10728
11194
|
GGML_ABORT("fatal error");
|
|
10729
11195
|
}
|
|
10730
11196
|
|
|
10731
|
-
ggml_cgraph *
|
|
10732
|
-
ggml_build_forward_expand(
|
|
11197
|
+
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
|
|
11198
|
+
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
|
|
10733
11199
|
|
|
10734
|
-
ggml_graph_compute_with_ctx(ggml_ctx,
|
|
11200
|
+
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
|
|
10735
11201
|
|
|
10736
11202
|
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
|
10737
11203
|
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
|
|
@@ -10754,10 +11220,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10754
11220
|
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
|
|
10755
11221
|
}
|
|
10756
11222
|
|
|
10757
|
-
static void ggml_vk_check_results_1(
|
|
11223
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
11224
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10758
11225
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10759
11226
|
return;
|
|
10760
11227
|
}
|
|
11228
|
+
bool fused_rms_norm_mul = false;
|
|
11229
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
11230
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
11231
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
11232
|
+
fused_rms_norm_mul = true;
|
|
11233
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
11234
|
+
}
|
|
11235
|
+
|
|
10761
11236
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10762
11237
|
return;
|
|
10763
11238
|
}
|