@novastera-oss/llamarn 0.2.9 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +17 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.h +4 -0
- package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +0 -40
- package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
- package/cpp/llama.cpp/src/llama-arch.h +18 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
- package/cpp/llama.cpp/src/llama-batch.h +8 -1
- package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
- package/cpp/llama.cpp/src/llama-graph.h +47 -60
- package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
- package/cpp/llama.cpp/src/llama-hparams.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
- package/cpp/llama.cpp/src/llama-model.h +18 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
- package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
- package/cpp/llama.cpp/src/llama-vocab.h +41 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +4 -0
- package/ios/include/llama.h +0 -40
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -224,6 +224,21 @@ enum vk_device_architecture {
|
|
|
224
224
|
INTEL_XE2,
|
|
225
225
|
};
|
|
226
226
|
|
|
227
|
+
// HSK x HSV
|
|
228
|
+
enum FaHeadSizes {
|
|
229
|
+
FA_HEAD_SIZE_64,
|
|
230
|
+
FA_HEAD_SIZE_80,
|
|
231
|
+
FA_HEAD_SIZE_96,
|
|
232
|
+
FA_HEAD_SIZE_112,
|
|
233
|
+
FA_HEAD_SIZE_128,
|
|
234
|
+
FA_HEAD_SIZE_192,
|
|
235
|
+
FA_HEAD_SIZE_192_128,
|
|
236
|
+
FA_HEAD_SIZE_256,
|
|
237
|
+
FA_HEAD_SIZE_576_512,
|
|
238
|
+
FA_HEAD_SIZE_UNSUPPORTED,
|
|
239
|
+
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
|
|
240
|
+
};
|
|
241
|
+
|
|
227
242
|
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
|
228
243
|
vk::PhysicalDeviceProperties props = device.getProperties();
|
|
229
244
|
|
|
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
|
|
305
320
|
}
|
|
306
321
|
|
|
307
322
|
struct vk_device_struct {
|
|
308
|
-
std::
|
|
323
|
+
std::recursive_mutex mutex;
|
|
309
324
|
|
|
310
325
|
vk::PhysicalDevice physical_device;
|
|
311
326
|
vk::PhysicalDeviceProperties properties;
|
|
@@ -410,32 +425,42 @@ struct vk_device_struct {
|
|
|
410
425
|
vk_pipeline pipeline_div_norepeat[2][2][2];
|
|
411
426
|
|
|
412
427
|
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
|
413
|
-
vk_pipeline
|
|
428
|
+
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
|
414
429
|
vk_pipeline pipeline_scale_f32;
|
|
415
430
|
vk_pipeline pipeline_sqr_f32;
|
|
416
431
|
vk_pipeline pipeline_sin_f32;
|
|
417
432
|
vk_pipeline pipeline_cos_f32;
|
|
418
433
|
vk_pipeline pipeline_clamp_f32;
|
|
419
434
|
vk_pipeline pipeline_pad_f32;
|
|
435
|
+
vk_pipeline pipeline_roll_f32;
|
|
420
436
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
|
421
437
|
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
|
|
422
438
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
|
|
423
439
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
|
424
440
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
|
441
|
+
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
|
|
425
442
|
vk_pipeline pipeline_norm_f32;
|
|
426
443
|
vk_pipeline pipeline_group_norm_f32;
|
|
427
444
|
vk_pipeline pipeline_rms_norm_f32;
|
|
445
|
+
vk_pipeline pipeline_rms_norm_mul_f32;
|
|
428
446
|
vk_pipeline pipeline_rms_norm_back_f32;
|
|
429
447
|
vk_pipeline pipeline_l2_norm_f32;
|
|
430
448
|
|
|
431
449
|
// [src/dst 0=fp32,1=fp16]
|
|
432
450
|
vk_pipeline pipeline_gelu[2];
|
|
451
|
+
vk_pipeline pipeline_gelu_erf[2];
|
|
433
452
|
vk_pipeline pipeline_gelu_quick[2];
|
|
434
453
|
vk_pipeline pipeline_silu[2];
|
|
435
454
|
vk_pipeline pipeline_relu[2];
|
|
436
455
|
vk_pipeline pipeline_tanh[2];
|
|
437
456
|
vk_pipeline pipeline_sigmoid[2];
|
|
438
457
|
|
|
458
|
+
vk_pipeline pipeline_geglu[2];
|
|
459
|
+
vk_pipeline pipeline_reglu[2];
|
|
460
|
+
vk_pipeline pipeline_swiglu[2];
|
|
461
|
+
vk_pipeline pipeline_geglu_erf[2];
|
|
462
|
+
vk_pipeline pipeline_geglu_quick[2];
|
|
463
|
+
|
|
439
464
|
vk_pipeline pipeline_leaky_relu_f32;
|
|
440
465
|
vk_pipeline pipeline_silu_back_f32;
|
|
441
466
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
@@ -461,26 +486,11 @@ struct vk_device_struct {
|
|
|
461
486
|
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
462
487
|
|
|
463
488
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
464
|
-
vk_pipeline
|
|
465
|
-
|
|
466
|
-
vk_pipeline
|
|
467
|
-
|
|
468
|
-
vk_pipeline
|
|
469
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
470
|
-
|
|
471
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
472
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
473
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
474
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
475
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
476
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
477
|
-
|
|
478
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
479
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
480
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
481
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
|
482
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
483
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
489
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
490
|
+
|
|
491
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
492
|
+
|
|
493
|
+
vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
484
494
|
|
|
485
495
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
486
496
|
|
|
@@ -493,6 +503,8 @@ struct vk_device_struct {
|
|
|
493
503
|
|
|
494
504
|
ggml_backend_buffer_type buffer_type;
|
|
495
505
|
|
|
506
|
+
bool disable_fusion;
|
|
507
|
+
|
|
496
508
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
497
509
|
std::unique_ptr<vk_memory_logger> memory_logger;
|
|
498
510
|
#endif
|
|
@@ -627,6 +639,8 @@ struct vk_flash_attn_push_constants {
|
|
|
627
639
|
uint32_t nev2;
|
|
628
640
|
uint32_t nev3;
|
|
629
641
|
uint32_t nem1;
|
|
642
|
+
uint32_t nem2;
|
|
643
|
+
uint32_t nem3;
|
|
630
644
|
|
|
631
645
|
uint32_t nb01;
|
|
632
646
|
uint32_t nb02;
|
|
@@ -637,14 +651,12 @@ struct vk_flash_attn_push_constants {
|
|
|
637
651
|
uint32_t nb21;
|
|
638
652
|
uint32_t nb22;
|
|
639
653
|
uint32_t nb23;
|
|
640
|
-
uint32_t nb31;
|
|
641
654
|
|
|
642
655
|
float scale;
|
|
643
656
|
float max_bias;
|
|
644
657
|
float logit_softcap;
|
|
645
658
|
|
|
646
|
-
uint32_t
|
|
647
|
-
uint32_t n_head_log2;
|
|
659
|
+
uint32_t mask_n_head_log2;
|
|
648
660
|
float m0;
|
|
649
661
|
float m1;
|
|
650
662
|
|
|
@@ -652,6 +664,7 @@ struct vk_flash_attn_push_constants {
|
|
|
652
664
|
uint32_t split_kv;
|
|
653
665
|
uint32_t k_num;
|
|
654
666
|
};
|
|
667
|
+
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
|
|
655
668
|
|
|
656
669
|
struct vk_op_push_constants {
|
|
657
670
|
uint32_t KX;
|
|
@@ -660,6 +673,13 @@ struct vk_op_push_constants {
|
|
|
660
673
|
float param2;
|
|
661
674
|
};
|
|
662
675
|
|
|
676
|
+
struct vk_op_glu_push_constants {
|
|
677
|
+
uint32_t N;
|
|
678
|
+
uint32_t ne00;
|
|
679
|
+
uint32_t ne20;
|
|
680
|
+
uint32_t mode; // 0: default, 1: swapped, 2: split
|
|
681
|
+
};
|
|
682
|
+
|
|
663
683
|
struct vk_op_unary_push_constants {
|
|
664
684
|
uint32_t ne;
|
|
665
685
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
@@ -675,6 +695,37 @@ struct vk_op_unary_push_constants {
|
|
|
675
695
|
};
|
|
676
696
|
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
|
677
697
|
|
|
698
|
+
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
|
|
699
|
+
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
|
|
700
|
+
ne = ne != 0 ? ne : ggml_nelements(dst);
|
|
701
|
+
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
|
|
702
|
+
|
|
703
|
+
vk_op_unary_push_constants p{};
|
|
704
|
+
p.ne = (uint32_t)ne;
|
|
705
|
+
|
|
706
|
+
size_t src0_tsize = ggml_type_size(src0->type);
|
|
707
|
+
p.ne00 = (uint32_t)src0->ne[0];
|
|
708
|
+
p.ne01 = (uint32_t)src0->ne[1];
|
|
709
|
+
p.ne02 = (uint32_t)src0->ne[2];
|
|
710
|
+
p.ne03 = (uint32_t)src0->ne[3];
|
|
711
|
+
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
|
|
712
|
+
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
|
|
713
|
+
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
|
|
714
|
+
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
|
|
715
|
+
|
|
716
|
+
size_t dst_tsize = ggml_type_size(dst->type);
|
|
717
|
+
p.ne10 = (uint32_t)dst->ne[0];
|
|
718
|
+
p.ne11 = (uint32_t)dst->ne[1];
|
|
719
|
+
p.ne12 = (uint32_t)dst->ne[2];
|
|
720
|
+
p.ne13 = (uint32_t)dst->ne[3];
|
|
721
|
+
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
|
|
722
|
+
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
|
|
723
|
+
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
|
|
724
|
+
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
|
|
725
|
+
|
|
726
|
+
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
|
727
|
+
}
|
|
728
|
+
|
|
678
729
|
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
|
679
730
|
// Precompute mp (m' in the paper) and L such that division
|
|
680
731
|
// can be computed using a multiply (high 32b of 64b result)
|
|
@@ -743,6 +794,14 @@ struct vk_op_rope_push_constants {
|
|
|
743
794
|
struct vk_op_soft_max_push_constants {
|
|
744
795
|
uint32_t KX;
|
|
745
796
|
uint32_t KY;
|
|
797
|
+
uint32_t ne00;
|
|
798
|
+
uint32_t ne01;
|
|
799
|
+
uint32_t ne02;
|
|
800
|
+
uint32_t ne12;
|
|
801
|
+
uint32_t ne13;
|
|
802
|
+
uint32_t nb11;
|
|
803
|
+
uint32_t nb12;
|
|
804
|
+
uint32_t nb13;
|
|
746
805
|
float scale;
|
|
747
806
|
float max_bias;
|
|
748
807
|
float m0;
|
|
@@ -836,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
|
|
|
836
895
|
|
|
837
896
|
struct vk_op_upscale_push_constants {
|
|
838
897
|
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
|
898
|
+
uint32_t ne00; uint32_t ne01;
|
|
839
899
|
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
840
900
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
|
841
901
|
float sf0; float sf1; float sf2; float sf3;
|
|
@@ -978,6 +1038,10 @@ struct ggml_backend_vk_context {
|
|
|
978
1038
|
|
|
979
1039
|
vk_command_pool compute_cmd_pool;
|
|
980
1040
|
vk_command_pool transfer_cmd_pool;
|
|
1041
|
+
|
|
1042
|
+
// number of additional consecutive nodes that are being fused with the
|
|
1043
|
+
// node currently being processed
|
|
1044
|
+
int num_additional_fused_ops {};
|
|
981
1045
|
};
|
|
982
1046
|
|
|
983
1047
|
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
|
@@ -1063,8 +1127,8 @@ static size_t vk_skip_checks;
|
|
|
1063
1127
|
static size_t vk_output_tensor;
|
|
1064
1128
|
|
|
1065
1129
|
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
|
|
1066
|
-
static void ggml_vk_check_results_0(
|
|
1067
|
-
static void ggml_vk_check_results_1(
|
|
1130
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1131
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1068
1132
|
#endif
|
|
1069
1133
|
|
|
1070
1134
|
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
|
@@ -1197,7 +1261,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
1197
1261
|
}
|
|
1198
1262
|
|
|
1199
1263
|
{
|
|
1200
|
-
std::lock_guard<std::
|
|
1264
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1201
1265
|
device->pipelines.insert({ pipeline->name, pipeline });
|
|
1202
1266
|
}
|
|
1203
1267
|
|
|
@@ -1411,7 +1475,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
|
|
|
1411
1475
|
|
|
1412
1476
|
static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
|
|
1413
1477
|
VK_LOG_DEBUG("ggml_vk_create_queue()");
|
|
1414
|
-
std::lock_guard<std::
|
|
1478
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1415
1479
|
|
|
1416
1480
|
q.queue_family_index = queue_family_index;
|
|
1417
1481
|
q.transfer_only = transfer_only;
|
|
@@ -1673,10 +1737,46 @@ enum FaCodePath {
|
|
|
1673
1737
|
FA_COOPMAT2,
|
|
1674
1738
|
};
|
|
1675
1739
|
|
|
1740
|
+
// Translate an (hsk, hsv) pair into the matching FaHeadSizes enumerator.
// hsk/hsv are presumably the K and V head sizes of the attention op
// (NOTE(review): inferred from the names, confirm against the callers).
//
// Supported combinations:
//   - hsk == hsv for 64, 80, 96, 112, 128, 192 and 256
//   - the two mixed pairs 192/128 and 576/512
// Every other combination (including 576/576) yields FA_HEAD_SIZE_UNSUPPORTED.
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
    // The only pairs where the two sizes are allowed to differ.
    if (hsk == 192 && hsv == 128) {
        return FA_HEAD_SIZE_192_128;
    }
    if (hsk == 576 && hsv == 512) {
        return FA_HEAD_SIZE_576_512;
    }

    // From here on only symmetric sizes can be valid.
    if (hsk != hsv) {
        return FA_HEAD_SIZE_UNSUPPORTED;
    }

    switch (hsk) {
        case 64:  return FA_HEAD_SIZE_64;
        case 80:  return FA_HEAD_SIZE_80;
        case 96:  return FA_HEAD_SIZE_96;
        case 112: return FA_HEAD_SIZE_112;
        case 128: return FA_HEAD_SIZE_128;
        case 192: return FA_HEAD_SIZE_192;
        case 256: return FA_HEAD_SIZE_256;
        default:  return FA_HEAD_SIZE_UNSUPPORTED; // also covers 576/576
    }
}
|
|
1768
|
+
|
|
1676
1769
|
// number of rows/cols for flash attention shader
|
|
1677
1770
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1678
1771
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
1679
|
-
|
|
1772
|
+
|
|
1773
|
+
// Row count returned for the scalar flash-attention "large rows" case.
// Returns 2 once hsv reaches 512, and 8 for every smaller value.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
    const uint32_t rows_when_large_hsv = 2;
    const uint32_t rows_default        = 8;

    return (hsv < 512) ? rows_default : rows_when_large_hsv;
}
|
|
1680
1780
|
|
|
1681
1781
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
1682
1782
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
@@ -1693,14 +1793,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|
|
1693
1793
|
}
|
|
1694
1794
|
}
|
|
1695
1795
|
|
|
1696
|
-
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t
|
|
1796
|
+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1697
1797
|
GGML_UNUSED(clamp);
|
|
1798
|
+
GGML_UNUSED(hsv);
|
|
1698
1799
|
|
|
1699
1800
|
if (path == FA_SCALAR) {
|
|
1700
1801
|
if (small_rows) {
|
|
1701
1802
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
1702
1803
|
} else {
|
|
1703
|
-
return {
|
|
1804
|
+
return {get_fa_scalar_num_large_rows(hsv), 32};
|
|
1704
1805
|
}
|
|
1705
1806
|
}
|
|
1706
1807
|
|
|
@@ -1718,8 +1819,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
|
|
|
1718
1819
|
}
|
|
1719
1820
|
|
|
1720
1821
|
// small cols to reduce register count
|
|
1721
|
-
if (ggml_is_quantized(type) ||
|
|
1722
|
-
|
|
1822
|
+
if (ggml_is_quantized(type) || hsk >= 256) {
|
|
1823
|
+
if (hsk >= 512) {
|
|
1824
|
+
return {32, 32};
|
|
1825
|
+
} else {
|
|
1826
|
+
return {64, 32};
|
|
1827
|
+
}
|
|
1723
1828
|
}
|
|
1724
1829
|
return {64, 64};
|
|
1725
1830
|
}
|
|
@@ -1761,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
1761
1866
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
1762
1867
|
|
|
1763
1868
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
1764
|
-
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
|
1869
|
+
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
|
|
1765
1870
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
1766
1871
|
|
|
1767
1872
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
@@ -1886,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1886
1991
|
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
|
1887
1992
|
|
|
1888
1993
|
// spec constants and tile sizes for quant matmul_id
|
|
1889
|
-
l_warptile_mmqid = { 256, 128,
|
|
1994
|
+
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
|
1890
1995
|
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1891
1996
|
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1892
|
-
l_mmqid_wg_denoms = { 128,
|
|
1997
|
+
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
1893
1998
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1894
1999
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1895
2000
|
|
|
@@ -2011,19 +2116,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2011
2116
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
2012
2117
|
};
|
|
2013
2118
|
|
|
2014
|
-
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t
|
|
2015
|
-
return {fa_rows_cols(path,
|
|
2119
|
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
2120
|
+
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
|
2016
2121
|
};
|
|
2017
2122
|
|
|
2018
|
-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t
|
|
2123
|
+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
2019
2124
|
// For large number of rows, 128 invocations seems to work best.
|
|
2020
2125
|
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
2021
2126
|
// can't use 256 for D==80.
|
|
2022
2127
|
// For scalar, use 128 (arbitrary)
|
|
2128
|
+
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
|
2129
|
+
const uint32_t D = (hsk|hsv);
|
|
2023
2130
|
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
2024
2131
|
? scalar_flash_attention_workgroup_size
|
|
2025
2132
|
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
2026
|
-
auto rows_cols = fa_rows_cols(path,
|
|
2133
|
+
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
|
|
2027
2134
|
|
|
2028
2135
|
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
2029
2136
|
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
@@ -2032,26 +2139,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2032
2139
|
|
|
2033
2140
|
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
2034
2141
|
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
2035
|
-
return {wg_size, rows_cols[0], rows_cols[1],
|
|
2142
|
+
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
|
2036
2143
|
};
|
|
2037
2144
|
|
|
2038
|
-
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2039
|
-
ggml_vk_create_pipeline(device, device->
|
|
2040
|
-
ggml_vk_create_pipeline(device, device->
|
|
2041
|
-
ggml_vk_create_pipeline(device, device->
|
|
2042
|
-
ggml_vk_create_pipeline(device, device->
|
|
2043
|
-
ggml_vk_create_pipeline(device, device->
|
|
2044
|
-
ggml_vk_create_pipeline(device, device->
|
|
2045
|
-
ggml_vk_create_pipeline(device, device->
|
|
2046
|
-
ggml_vk_create_pipeline(device, device->
|
|
2145
|
+
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
|
|
2146
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2147
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2148
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2149
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2150
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2151
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2152
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2153
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2047
2154
|
|
|
2048
2155
|
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
2049
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
|
2050
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
|
2051
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
|
2052
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
|
2053
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
|
2054
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2156
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
|
|
2157
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
|
|
2158
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
|
|
2159
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
|
|
2160
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
|
|
2161
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
|
|
2162
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
|
|
2163
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
|
|
2164
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
|
|
2055
2165
|
|
|
2056
2166
|
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
2057
2167
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
@@ -2641,7 +2751,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2641
2751
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2642
2752
|
|
|
2643
2753
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
2644
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2,
|
|
2754
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
|
2645
2755
|
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
|
2646
2756
|
|
|
2647
2757
|
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
@@ -2655,7 +2765,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2655
2765
|
|
|
2656
2766
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2657
2767
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2658
|
-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main",
|
|
2768
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
|
|
2769
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
|
|
2659
2770
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2660
2771
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2661
2772
|
|
|
@@ -2672,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2672
2783
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2673
2784
|
|
|
2674
2785
|
if (device->float_controls_rte_fp16) {
|
|
2675
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2676
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2677
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2678
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2679
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2680
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2786
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2787
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2788
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2789
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2790
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2791
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2681
2792
|
} else {
|
|
2682
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2683
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2684
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2685
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2686
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2687
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2793
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2794
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2795
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2796
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2797
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2798
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2799
|
+
}
|
|
2800
|
+
|
|
2801
|
+
if (device->float_controls_rte_fp16) {
|
|
2802
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2803
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2804
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2805
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2806
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2807
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2808
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2809
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2810
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2811
|
+
} else {
|
|
2812
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2813
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2814
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2815
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2816
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2817
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2818
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2819
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2820
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2688
2821
|
}
|
|
2689
2822
|
|
|
2690
2823
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
@@ -2724,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2724
2857
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2725
2858
|
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2726
2859
|
|
|
2727
|
-
ggml_vk_create_pipeline(device, device->
|
|
2860
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
|
2861
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
|
2862
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
|
|
2728
2863
|
|
|
2729
2864
|
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2730
2865
|
|
|
@@ -2736,6 +2871,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2736
2871
|
|
|
2737
2872
|
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2738
2873
|
|
|
2874
|
+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2875
|
+
|
|
2739
2876
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2740
2877
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2741
2878
|
|
|
@@ -2744,6 +2881,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2744
2881
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2745
2882
|
|
|
2746
2883
|
CREATE_UNARY(gelu)
|
|
2884
|
+
CREATE_UNARY(gelu_erf)
|
|
2747
2885
|
CREATE_UNARY(gelu_quick)
|
|
2748
2886
|
CREATE_UNARY(silu)
|
|
2749
2887
|
CREATE_UNARY(relu)
|
|
@@ -2751,6 +2889,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2751
2889
|
CREATE_UNARY(sigmoid)
|
|
2752
2890
|
#undef CREATE_UNARY
|
|
2753
2891
|
|
|
2892
|
+
#define CREATE_GLU(name) \
|
|
2893
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2894
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
|
2895
|
+
|
|
2896
|
+
CREATE_GLU(geglu)
|
|
2897
|
+
CREATE_GLU(reglu)
|
|
2898
|
+
CREATE_GLU(swiglu)
|
|
2899
|
+
CREATE_GLU(geglu_erf)
|
|
2900
|
+
CREATE_GLU(geglu_quick)
|
|
2901
|
+
#undef CREATE_GLU
|
|
2902
|
+
|
|
2754
2903
|
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2755
2904
|
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2756
2905
|
|
|
@@ -3431,6 +3580,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3431
3580
|
|
|
3432
3581
|
device->idx = idx;
|
|
3433
3582
|
|
|
3583
|
+
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
|
3584
|
+
|
|
3434
3585
|
return device;
|
|
3435
3586
|
}
|
|
3436
3587
|
|
|
@@ -3651,7 +3802,6 @@ static void ggml_vk_instance_init() {
|
|
|
3651
3802
|
|
|
3652
3803
|
}
|
|
3653
3804
|
|
|
3654
|
-
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
|
3655
3805
|
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
|
3656
3806
|
|
|
3657
3807
|
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
|
@@ -4124,6 +4274,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
|
|
4124
4274
|
return nullptr;
|
|
4125
4275
|
}
|
|
4126
4276
|
|
|
4277
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4127
4278
|
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
|
|
4128
4279
|
|
|
4129
4280
|
return buf->ptr;
|
|
@@ -4134,6 +4285,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4134
4285
|
return;
|
|
4135
4286
|
}
|
|
4136
4287
|
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
|
|
4288
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4289
|
+
|
|
4137
4290
|
vk_buffer buf;
|
|
4138
4291
|
size_t index;
|
|
4139
4292
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4156,6 +4309,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4156
4309
|
}
|
|
4157
4310
|
|
|
4158
4311
|
static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
|
|
4312
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4159
4313
|
buf = nullptr;
|
|
4160
4314
|
buf_offset = 0;
|
|
4161
4315
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4457,7 +4611,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
|
|
|
4457
4611
|
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
|
|
4458
4612
|
}
|
|
4459
4613
|
} else {
|
|
4460
|
-
std::lock_guard<std::
|
|
4614
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4461
4615
|
|
|
4462
4616
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4463
4617
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
@@ -4548,7 +4702,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
|
|
|
4548
4702
|
|
|
4549
4703
|
memcpy(dst, (uint8_t *) src->ptr + offset, size);
|
|
4550
4704
|
} else {
|
|
4551
|
-
std::lock_guard<std::
|
|
4705
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4552
4706
|
|
|
4553
4707
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
4554
4708
|
ggml_vk_ctx_begin(src->device, subctx);
|
|
@@ -4578,7 +4732,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
|
|
|
4578
4732
|
|
|
4579
4733
|
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
|
|
4580
4734
|
if (src->device == dst->device) {
|
|
4581
|
-
std::lock_guard<std::
|
|
4735
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4582
4736
|
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
|
|
4583
4737
|
// Copy within the device
|
|
4584
4738
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
@@ -4613,7 +4767,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
|
|
|
4613
4767
|
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
|
|
4614
4768
|
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
|
|
4615
4769
|
|
|
4616
|
-
std::lock_guard<std::
|
|
4770
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4617
4771
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4618
4772
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
4619
4773
|
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
|
|
@@ -4840,9 +4994,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
4840
4994
|
// type size must be exactly 2 or 4.
|
|
4841
4995
|
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
|
|
4842
4996
|
if ((ggml_type_size(src->type) % 4) == 0) {
|
|
4843
|
-
|
|
4997
|
+
if (contig) {
|
|
4998
|
+
return ctx->device->pipeline_contig_cpy_f32_f32;
|
|
4999
|
+
} else {
|
|
5000
|
+
return ctx->device->pipeline_cpy_f32_f32;
|
|
5001
|
+
}
|
|
4844
5002
|
} else {
|
|
4845
|
-
|
|
5003
|
+
if (contig) {
|
|
5004
|
+
return ctx->device->pipeline_contig_cpy_f16_f16;
|
|
5005
|
+
} else {
|
|
5006
|
+
return ctx->device->pipeline_cpy_f16_f16;
|
|
5007
|
+
}
|
|
4846
5008
|
}
|
|
4847
5009
|
}
|
|
4848
5010
|
|
|
@@ -4903,7 +5065,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4903
5065
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
4904
5066
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
4905
5067
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
4906
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5068
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
4907
5069
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
4908
5070
|
|
|
4909
5071
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5131,7 +5293,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
5131
5293
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
5132
5294
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5133
5295
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
|
|
5134
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5296
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5135
5297
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5136
5298
|
|
|
5137
5299
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5732,7 +5894,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
5732
5894
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
|
5733
5895
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5734
5896
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
5735
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5897
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5736
5898
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5737
5899
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
|
5738
5900
|
|
|
@@ -5926,14 +6088,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5926
6088
|
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
|
5927
6089
|
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
|
|
5928
6090
|
} else {
|
|
5929
|
-
|
|
6091
|
+
// Split based on number of ids, to fit in shared memory
|
|
6092
|
+
const uint32_t nei0 = (uint32_t)src2->ne[0];
|
|
6093
|
+
const uint32_t nei1 = (uint32_t)src2->ne[1];
|
|
6094
|
+
|
|
6095
|
+
GGML_ASSERT(nei0 <= 4096);
|
|
6096
|
+
const uint32_t split_size = std::min(nei1, 4096u / nei0);
|
|
6097
|
+
|
|
6098
|
+
ggml_tensor src1_copy = *src1;
|
|
6099
|
+
ggml_tensor src2_copy = *src2;
|
|
6100
|
+
ggml_tensor dst_copy = *dst;
|
|
6101
|
+
|
|
6102
|
+
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
|
|
6103
|
+
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
|
|
6104
|
+
|
|
6105
|
+
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
|
|
6106
|
+
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
|
|
6107
|
+
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
|
|
6108
|
+
|
|
6109
|
+
src1_copy.ne[2] = n_tokens;
|
|
6110
|
+
src2_copy.ne[1] = n_tokens;
|
|
6111
|
+
dst_copy.ne[2] = n_tokens;
|
|
6112
|
+
|
|
6113
|
+
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
|
|
6114
|
+
}
|
|
5930
6115
|
}
|
|
5931
6116
|
}
|
|
5932
6117
|
|
|
5933
|
-
static bool
|
|
6118
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
|
6119
|
+
// Needs to be kept up to date on shader changes
|
|
6120
|
+
GGML_UNUSED(hsv);
|
|
6121
|
+
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
6122
|
+
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
|
|
6123
|
+
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
6124
|
+
|
|
6125
|
+
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
6126
|
+
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
|
6127
|
+
|
|
6128
|
+
const uint32_t masksh = Bc * Br * sizeof(float);
|
|
6129
|
+
|
|
6130
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
|
6131
|
+
|
|
6132
|
+
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
|
6133
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
6134
|
+
|
|
6135
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
|
6136
|
+
|
|
6137
|
+
return supported;
|
|
6138
|
+
}
|
|
6139
|
+
|
|
6140
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
5934
6141
|
// Needs to be kept up to date on shader changes
|
|
6142
|
+
GGML_UNUSED(hsv);
|
|
5935
6143
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
5936
|
-
const uint32_t Br =
|
|
6144
|
+
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
|
5937
6145
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
5938
6146
|
|
|
5939
6147
|
const uint32_t acctype = f32acc ? 4 : 2;
|
|
@@ -5942,12 +6150,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5942
6150
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
5943
6151
|
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
5944
6152
|
|
|
5945
|
-
const uint32_t Qf = Br * (
|
|
6153
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
|
|
5946
6154
|
|
|
5947
|
-
const uint32_t sfshstride = (
|
|
6155
|
+
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
|
5948
6156
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
5949
6157
|
|
|
5950
|
-
const uint32_t kshstride =
|
|
6158
|
+
const uint32_t kshstride = hsk / 4 + 2;
|
|
5951
6159
|
const uint32_t ksh = Bc * kshstride * f16vec4;
|
|
5952
6160
|
|
|
5953
6161
|
const uint32_t slope = Br * sizeof(float);
|
|
@@ -5955,7 +6163,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5955
6163
|
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
|
5956
6164
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
5957
6165
|
|
|
5958
|
-
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(
|
|
6166
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
5959
6167
|
|
|
5960
6168
|
return supported;
|
|
5961
6169
|
}
|
|
@@ -5977,13 +6185,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5977
6185
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
5978
6186
|
|
|
5979
6187
|
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
|
5980
|
-
const uint32_t
|
|
6188
|
+
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
|
6189
|
+
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
|
5981
6190
|
|
|
5982
|
-
const uint32_t
|
|
6191
|
+
const uint32_t HSK = nek0;
|
|
6192
|
+
const uint32_t HSV = nev0;
|
|
5983
6193
|
uint32_t N = neq1;
|
|
5984
6194
|
const uint32_t KV = nek1;
|
|
5985
6195
|
|
|
5986
|
-
GGML_ASSERT(ne0 ==
|
|
6196
|
+
GGML_ASSERT(ne0 == HSV);
|
|
5987
6197
|
GGML_ASSERT(ne2 == N);
|
|
5988
6198
|
|
|
5989
6199
|
// input tensor rows must be contiguous
|
|
@@ -5991,12 +6201,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5991
6201
|
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
5992
6202
|
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
5993
6203
|
|
|
5994
|
-
GGML_ASSERT(neq0 ==
|
|
5995
|
-
GGML_ASSERT(nek0 == D);
|
|
5996
|
-
GGML_ASSERT(nev0 == D);
|
|
6204
|
+
GGML_ASSERT(neq0 == HSK);
|
|
5997
6205
|
|
|
5998
6206
|
GGML_ASSERT(neq1 == N);
|
|
5999
|
-
GGML_ASSERT(nev0 == D);
|
|
6000
6207
|
|
|
6001
6208
|
GGML_ASSERT(nev1 == nek1);
|
|
6002
6209
|
|
|
@@ -6017,7 +6224,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6017
6224
|
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
6018
6225
|
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
6019
6226
|
|
|
6020
|
-
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device,
|
|
6227
|
+
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
|
6021
6228
|
|
|
6022
6229
|
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
6023
6230
|
path = FA_SCALAR;
|
|
@@ -6037,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6037
6244
|
case FA_SCALAR:
|
|
6038
6245
|
case FA_COOPMAT1:
|
|
6039
6246
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
6040
|
-
max_gqa =
|
|
6247
|
+
max_gqa = get_fa_scalar_num_large_rows(HSV);
|
|
6041
6248
|
break;
|
|
6042
6249
|
case FA_COOPMAT2:
|
|
6043
6250
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
@@ -6047,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6047
6254
|
}
|
|
6048
6255
|
|
|
6049
6256
|
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
6050
|
-
qk_ratio * nek2 == neq2 && nek2 == nev2 &&
|
|
6257
|
+
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
|
6051
6258
|
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
6052
6259
|
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
6053
6260
|
// and change addressing calculations to index Q's dimension 2.
|
|
@@ -6070,47 +6277,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6070
6277
|
path = FA_SCALAR;
|
|
6071
6278
|
}
|
|
6072
6279
|
|
|
6280
|
+
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
|
6281
|
+
if (path == FA_SCALAR &&
|
|
6282
|
+
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
|
|
6283
|
+
small_rows = true;
|
|
6284
|
+
}
|
|
6285
|
+
|
|
6073
6286
|
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
|
6074
6287
|
|
|
6288
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
|
|
6289
|
+
|
|
6075
6290
|
switch (path) {
|
|
6076
6291
|
case FA_SCALAR:
|
|
6077
|
-
|
|
6078
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
6079
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
6080
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
6081
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
6082
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
6083
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
6084
|
-
default:
|
|
6085
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6086
|
-
return;
|
|
6087
|
-
}
|
|
6292
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
|
|
6088
6293
|
break;
|
|
6089
6294
|
case FA_COOPMAT1:
|
|
6090
|
-
|
|
6091
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6092
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6093
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6094
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6095
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6096
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6097
|
-
default:
|
|
6098
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6099
|
-
return;
|
|
6100
|
-
}
|
|
6295
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
|
|
6101
6296
|
break;
|
|
6102
6297
|
case FA_COOPMAT2:
|
|
6103
|
-
|
|
6104
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6105
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6106
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6107
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6108
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6109
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6110
|
-
default:
|
|
6111
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6112
|
-
return;
|
|
6113
|
-
}
|
|
6298
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
|
|
6114
6299
|
break;
|
|
6115
6300
|
default:
|
|
6116
6301
|
GGML_ASSERT(0);
|
|
@@ -6138,21 +6323,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6138
6323
|
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
6139
6324
|
|
|
6140
6325
|
// Try to use split_k when KV is large enough to be worth the overhead
|
|
6141
|
-
if (workgroups_x == 1 && shader_core_count > 0
|
|
6326
|
+
if (workgroups_x == 1 && shader_core_count > 0) {
|
|
6142
6327
|
// Try to run two workgroups per SM.
|
|
6143
|
-
split_k =
|
|
6328
|
+
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
|
6144
6329
|
if (split_k > 1) {
|
|
6145
6330
|
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
|
6146
6331
|
// of "align", so recompute split_k based on that.
|
|
6147
|
-
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
|
6332
|
+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
|
6148
6333
|
split_k = CEIL_DIV(KV, split_kv);
|
|
6149
6334
|
workgroups_x = split_k;
|
|
6150
6335
|
}
|
|
6151
6336
|
}
|
|
6152
6337
|
|
|
6153
|
-
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
|
|
6154
|
-
// and the per-row m and L values (ne1 rows).
|
|
6155
|
-
const uint64_t split_k_size = split_k > 1 ? (
|
|
6338
|
+
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (HSV x ne1)
|
|
6339
|
+
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
|
6340
|
+
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
|
6156
6341
|
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
|
6157
6342
|
GGML_ABORT("Requested preallocation size is too large");
|
|
6158
6343
|
}
|
|
@@ -6239,18 +6424,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6239
6424
|
}
|
|
6240
6425
|
}
|
|
6241
6426
|
|
|
6427
|
+
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
|
|
6428
|
+
|
|
6242
6429
|
const vk_flash_attn_push_constants pc = { N, KV,
|
|
6243
6430
|
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
|
6244
6431
|
(uint32_t)neq2, (uint32_t)neq3,
|
|
6245
6432
|
(uint32_t)nek2, (uint32_t)nek3,
|
|
6246
6433
|
(uint32_t)nev2, (uint32_t)nev3,
|
|
6247
|
-
nem1,
|
|
6434
|
+
nem1, nem2, nem3,
|
|
6248
6435
|
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
|
6249
6436
|
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
|
6250
6437
|
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
|
6251
|
-
nbm1,
|
|
6252
6438
|
scale, max_bias, logit_softcap,
|
|
6253
|
-
|
|
6439
|
+
mask_n_head_log2, m0, m1,
|
|
6254
6440
|
gqa_ratio, split_kv, split_k };
|
|
6255
6441
|
|
|
6256
6442
|
ggml_vk_sync_buffers(subctx);
|
|
@@ -6271,13 +6457,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6271
6457
|
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
|
6272
6458
|
|
|
6273
6459
|
ggml_vk_sync_buffers(subctx);
|
|
6274
|
-
const std::array<uint32_t,
|
|
6460
|
+
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
|
|
6275
6461
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
|
6276
6462
|
{
|
|
6277
6463
|
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
|
6278
6464
|
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
|
6279
6465
|
},
|
|
6280
|
-
pc2, { (uint32_t)ne1,
|
|
6466
|
+
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
|
6281
6467
|
} else {
|
|
6282
6468
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
6283
6469
|
{
|
|
@@ -6353,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6353
6539
|
}
|
|
6354
6540
|
return nullptr;
|
|
6355
6541
|
case GGML_OP_UPSCALE:
|
|
6356
|
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
|
6357
|
-
|
|
6542
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6543
|
+
int mode = ggml_get_op_params_i32(dst, 0);
|
|
6544
|
+
switch (mode) {
|
|
6545
|
+
case GGML_SCALE_MODE_NEAREST:
|
|
6546
|
+
return ctx->device->pipeline_upscale_nearest_f32;
|
|
6547
|
+
case GGML_SCALE_MODE_BILINEAR:
|
|
6548
|
+
return ctx->device->pipeline_upscale_bilinear_f32;
|
|
6549
|
+
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
|
|
6550
|
+
return ctx->device->pipeline_upscale_bilinear_ac_f32;
|
|
6551
|
+
}
|
|
6358
6552
|
}
|
|
6359
6553
|
return nullptr;
|
|
6360
6554
|
case GGML_OP_SCALE:
|
|
@@ -6387,6 +6581,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6387
6581
|
return ctx->device->pipeline_pad_f32;
|
|
6388
6582
|
}
|
|
6389
6583
|
return nullptr;
|
|
6584
|
+
case GGML_OP_ROLL:
|
|
6585
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6586
|
+
return ctx->device->pipeline_roll_f32;
|
|
6587
|
+
}
|
|
6588
|
+
return nullptr;
|
|
6390
6589
|
case GGML_OP_REPEAT:
|
|
6391
6590
|
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
|
|
6392
6591
|
return ctx->device->pipeline_repeat_f32;
|
|
@@ -6401,6 +6600,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6401
6600
|
case GGML_OP_CONT:
|
|
6402
6601
|
case GGML_OP_DUP:
|
|
6403
6602
|
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
|
6603
|
+
case GGML_OP_SET_ROWS:
|
|
6604
|
+
return ctx->device->pipeline_set_rows[dst->type];
|
|
6404
6605
|
case GGML_OP_SILU_BACK:
|
|
6405
6606
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6406
6607
|
return ctx->device->pipeline_silu_back_f32;
|
|
@@ -6418,7 +6619,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6418
6619
|
return nullptr;
|
|
6419
6620
|
case GGML_OP_RMS_NORM:
|
|
6420
6621
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6421
|
-
return ctx->device->pipeline_rms_norm_f32;
|
|
6622
|
+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
|
|
6422
6623
|
}
|
|
6423
6624
|
return nullptr;
|
|
6424
6625
|
case GGML_OP_RMS_NORM_BACK:
|
|
@@ -6443,6 +6644,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6443
6644
|
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
6444
6645
|
case GGML_UNARY_OP_GELU:
|
|
6445
6646
|
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
|
|
6647
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
6648
|
+
return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
|
|
6446
6649
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
6447
6650
|
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
|
|
6448
6651
|
case GGML_UNARY_OP_RELU:
|
|
@@ -6455,6 +6658,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6455
6658
|
break;
|
|
6456
6659
|
}
|
|
6457
6660
|
return nullptr;
|
|
6661
|
+
case GGML_OP_GLU:
|
|
6662
|
+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
|
6663
|
+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
|
|
6664
|
+
(src0->type != dst->type)) {
|
|
6665
|
+
return nullptr;
|
|
6666
|
+
}
|
|
6667
|
+
|
|
6668
|
+
switch (ggml_get_glu_op(dst)) {
|
|
6669
|
+
case GGML_GLU_OP_GEGLU:
|
|
6670
|
+
return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
|
|
6671
|
+
case GGML_GLU_OP_REGLU:
|
|
6672
|
+
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
|
6673
|
+
case GGML_GLU_OP_SWIGLU:
|
|
6674
|
+
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
|
6675
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
6676
|
+
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
|
|
6677
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
6678
|
+
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
|
|
6679
|
+
default:
|
|
6680
|
+
break;
|
|
6681
|
+
}
|
|
6682
|
+
return nullptr;
|
|
6458
6683
|
case GGML_OP_DIAG_MASK_INF:
|
|
6459
6684
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6460
6685
|
return ctx->device->pipeline_diag_mask_inf_f32;
|
|
@@ -6615,6 +6840,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
6615
6840
|
case GGML_OP_RMS_NORM:
|
|
6616
6841
|
case GGML_OP_CONV_2D_DW:
|
|
6617
6842
|
case GGML_OP_IM2COL:
|
|
6843
|
+
case GGML_OP_SET_ROWS:
|
|
6618
6844
|
return true;
|
|
6619
6845
|
default:
|
|
6620
6846
|
return false;
|
|
@@ -6909,12 +7135,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6909
7135
|
case GGML_OP_COS:
|
|
6910
7136
|
case GGML_OP_CLAMP:
|
|
6911
7137
|
case GGML_OP_PAD:
|
|
7138
|
+
case GGML_OP_ROLL:
|
|
6912
7139
|
case GGML_OP_REPEAT:
|
|
6913
7140
|
case GGML_OP_REPEAT_BACK:
|
|
6914
7141
|
case GGML_OP_CPY:
|
|
6915
7142
|
case GGML_OP_CONCAT:
|
|
6916
7143
|
case GGML_OP_UPSCALE:
|
|
6917
7144
|
case GGML_OP_UNARY:
|
|
7145
|
+
case GGML_OP_GLU:
|
|
6918
7146
|
case GGML_OP_CONV_2D_DW:
|
|
6919
7147
|
{
|
|
6920
7148
|
uint32_t ne = ggml_nelements(dst);
|
|
@@ -6927,6 +7155,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6927
7155
|
ne *= ggml_type_size(src0->type) / 2;
|
|
6928
7156
|
}
|
|
6929
7157
|
}
|
|
7158
|
+
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
|
|
7159
|
+
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
|
|
7160
|
+
// So divide by block size here before splitting into 512x512 groups.
|
|
7161
|
+
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7162
|
+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
|
|
7163
|
+
}
|
|
6930
7164
|
if (ne > 262144) {
|
|
6931
7165
|
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
6932
7166
|
} else if (ne > 512) {
|
|
@@ -6935,6 +7169,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6935
7169
|
elements = { ne, 1, 1 };
|
|
6936
7170
|
}
|
|
6937
7171
|
} break;
|
|
7172
|
+
case GGML_OP_SET_ROWS:
|
|
7173
|
+
{
|
|
7174
|
+
uint32_t ne = ggml_nelements(src0);
|
|
7175
|
+
if (ggml_is_quantized(dst->type)) {
|
|
7176
|
+
// quants run 32 threads each doing QUANT_K elements
|
|
7177
|
+
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
|
|
7178
|
+
} else {
|
|
7179
|
+
// scalar types do one element per thread, running 512 threads
|
|
7180
|
+
ne = CEIL_DIV(ne, 512);
|
|
7181
|
+
}
|
|
7182
|
+
if (ne > 262144) {
|
|
7183
|
+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
7184
|
+
} else if (ne > 512) {
|
|
7185
|
+
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
|
7186
|
+
} else {
|
|
7187
|
+
elements = { ne, 1, 1 };
|
|
7188
|
+
}
|
|
7189
|
+
}
|
|
7190
|
+
break;
|
|
6938
7191
|
default:
|
|
6939
7192
|
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
|
|
6940
7193
|
break;
|
|
@@ -6955,7 +7208,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6955
7208
|
}
|
|
6956
7209
|
}
|
|
6957
7210
|
|
|
6958
|
-
if (op == GGML_OP_SOFT_MAX) {
|
|
7211
|
+
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
|
|
6959
7212
|
// Empty src1 is possible in soft_max and glu, but the shader needs a buffer
|
|
6960
7213
|
vk_subbuffer subbuf_y;
|
|
6961
7214
|
if (use_src1) {
|
|
@@ -7344,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7344
7597
|
|
|
7345
7598
|
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7346
7599
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7600
|
+
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
|
7601
|
+
|
|
7602
|
+
float sf0 = (float)dst->ne[0] / src0->ne[0];
|
|
7603
|
+
float sf1 = (float)dst->ne[1] / src0->ne[1];
|
|
7604
|
+
float sf2 = (float)dst->ne[2] / src0->ne[2];
|
|
7605
|
+
float sf3 = (float)dst->ne[3] / src0->ne[3];
|
|
7347
7606
|
|
|
7348
|
-
|
|
7349
|
-
|
|
7350
|
-
|
|
7351
|
-
|
|
7607
|
+
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7608
|
+
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
|
7609
|
+
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
|
7610
|
+
}
|
|
7352
7611
|
|
|
7353
7612
|
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
|
7354
7613
|
(uint32_t)ggml_nelements(dst), 0, 0,
|
|
7614
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
|
|
7355
7615
|
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7356
7616
|
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
|
7357
7617
|
sf0, sf1, sf2, sf3,
|
|
@@ -7359,123 +7619,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
7359
7619
|
}
|
|
7360
7620
|
|
|
7361
7621
|
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7362
|
-
|
|
7363
|
-
|
|
7364
|
-
|
|
7622
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7623
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7624
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7365
7625
|
|
|
7366
|
-
ggml_vk_op_f32
|
|
7367
|
-
(uint32_t)ggml_nelements(src0),
|
|
7368
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7369
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7370
|
-
0,
|
|
7371
|
-
op_params[0], 0.0f,
|
|
7372
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7373
|
-
}, dryrun);
|
|
7626
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
|
|
7374
7627
|
}
|
|
7375
7628
|
|
|
7376
7629
|
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7377
|
-
|
|
7378
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7379
|
-
|
|
7380
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
|
|
7381
|
-
(uint32_t)ggml_nelements(src0),
|
|
7382
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7383
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7384
|
-
0,
|
|
7385
|
-
0.0f, 0.0f,
|
|
7386
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7387
|
-
}, dryrun);
|
|
7630
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7388
7631
|
}
|
|
7389
7632
|
|
|
7390
7633
|
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7391
|
-
|
|
7392
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7393
|
-
|
|
7394
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
|
7395
|
-
(uint32_t)ggml_nelements(src0),
|
|
7396
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7397
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7398
|
-
0,
|
|
7399
|
-
0.0f, 0.0f,
|
|
7400
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7401
|
-
}, dryrun);
|
|
7634
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7402
7635
|
}
|
|
7403
7636
|
|
|
7404
7637
|
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7405
|
-
|
|
7406
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7407
|
-
|
|
7408
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
|
7409
|
-
(uint32_t)ggml_nelements(src0),
|
|
7410
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7411
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7412
|
-
0,
|
|
7413
|
-
0.0f, 0.0f,
|
|
7414
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7415
|
-
}, dryrun);
|
|
7638
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7416
7639
|
}
|
|
7417
7640
|
|
|
7418
7641
|
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7419
|
-
|
|
7420
|
-
|
|
7421
|
-
|
|
7642
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7643
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7644
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7422
7645
|
|
|
7423
|
-
ggml_vk_op_f32
|
|
7424
|
-
(uint32_t)ggml_nelements(src0),
|
|
7425
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7426
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7427
|
-
0,
|
|
7428
|
-
op_params[0], op_params[1],
|
|
7429
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7430
|
-
}, dryrun);
|
|
7646
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
|
|
7431
7647
|
}
|
|
7432
7648
|
|
|
7433
7649
|
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7434
|
-
|
|
7435
|
-
|
|
7650
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7651
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
|
|
7652
|
+
}
|
|
7436
7653
|
|
|
7437
|
-
|
|
7438
|
-
|
|
7439
|
-
|
|
7440
|
-
|
|
7441
|
-
|
|
7442
|
-
|
|
7443
|
-
|
|
7444
|
-
|
|
7654
|
+
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7655
|
+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
|
7656
|
+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
|
7657
|
+
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
|
|
7658
|
+
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
|
|
7659
|
+
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
|
|
7660
|
+
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
|
|
7661
|
+
|
|
7662
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7663
|
+
memcpy(&p.param1, &s01_packed, sizeof(float));
|
|
7664
|
+
memcpy(&p.param2, &s23_packed, sizeof(float));
|
|
7665
|
+
|
|
7666
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
|
|
7445
7667
|
}
|
|
7446
7668
|
|
|
7447
7669
|
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7448
|
-
|
|
7449
|
-
|
|
7450
|
-
|
|
7451
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
|
|
7452
|
-
(uint32_t)ggml_nelements(dst),
|
|
7453
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7454
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7455
|
-
0,
|
|
7456
|
-
0.0f, 0.0f,
|
|
7457
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7458
|
-
}, dryrun);
|
|
7670
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7671
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
|
|
7459
7672
|
}
|
|
7460
7673
|
|
|
7461
7674
|
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7462
|
-
|
|
7463
|
-
|
|
7464
|
-
|
|
7465
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
|
|
7466
|
-
(uint32_t)ggml_nelements(dst),
|
|
7467
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7468
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7469
|
-
0,
|
|
7470
|
-
0.0f, 0.0f,
|
|
7471
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7472
|
-
}, dryrun);
|
|
7675
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7676
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
|
|
7473
7677
|
}
|
|
7474
7678
|
|
|
7475
7679
|
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7476
|
-
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7477
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7478
|
-
|
|
7479
7680
|
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
|
7480
7681
|
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7481
7682
|
// Convert from number of logical elements to 2- or 4-byte units.
|
|
@@ -7487,13 +7688,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7487
7688
|
}
|
|
7488
7689
|
}
|
|
7489
7690
|
|
|
7490
|
-
|
|
7491
|
-
|
|
7492
|
-
|
|
7493
|
-
|
|
7691
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
|
|
7692
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
|
|
7693
|
+
}
|
|
7694
|
+
|
|
7695
|
+
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7696
|
+
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7697
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7698
|
+
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7699
|
+
|
|
7700
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
|
|
7701
|
+
(uint32_t)ggml_nelements(src0),
|
|
7702
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7703
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7704
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7494
7705
|
0,
|
|
7495
|
-
0.0f, 0.0f,
|
|
7496
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7706
|
+
0.0f, 0.0f, 0,
|
|
7497
7707
|
}, dryrun);
|
|
7498
7708
|
}
|
|
7499
7709
|
|
|
@@ -7518,18 +7728,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
7518
7728
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
|
7519
7729
|
}
|
|
7520
7730
|
|
|
7521
|
-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7522
|
-
float * op_params = (float *)dst->op_params;
|
|
7731
|
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
|
7523
7732
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7733
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7524
7734
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7525
7735
|
|
|
7526
|
-
ggml_vk_op_f32<
|
|
7736
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
|
|
7527
7737
|
(uint32_t)ggml_nelements(src0),
|
|
7528
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7529
|
-
(uint32_t)
|
|
7738
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7739
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7740
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7530
7741
|
0,
|
|
7531
|
-
op_params[0], 0.0f,
|
|
7532
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7742
|
+
op_params[0], 0.0f, 0,
|
|
7533
7743
|
}, dryrun);
|
|
7534
7744
|
}
|
|
7535
7745
|
|
|
@@ -7547,6 +7757,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
7547
7757
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
|
7548
7758
|
}
|
|
7549
7759
|
|
|
7760
|
+
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7761
|
+
const bool swapped = (bool)dst->op_params[1];
|
|
7762
|
+
const bool split = src1 != nullptr;
|
|
7763
|
+
|
|
7764
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
7765
|
+
|
|
7766
|
+
if (!split) {
|
|
7767
|
+
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
|
7768
|
+
} else {
|
|
7769
|
+
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
|
|
7770
|
+
GGML_ASSERT(src0->ne[0] == dst->ne[0]);
|
|
7771
|
+
GGML_ASSERT(src0->type == src1->type);
|
|
7772
|
+
}
|
|
7773
|
+
|
|
7774
|
+
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
|
|
7775
|
+
|
|
7776
|
+
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
|
|
7777
|
+
}
|
|
7778
|
+
|
|
7550
7779
|
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7551
7780
|
int32_t * op_params = (int32_t *)dst->op_params;
|
|
7552
7781
|
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
|
@@ -7562,7 +7791,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7562
7791
|
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
|
|
7563
7792
|
const uint32_t nrows_y = (uint32_t)src0->ne[1];
|
|
7564
7793
|
|
|
7565
|
-
const uint32_t
|
|
7794
|
+
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
|
|
7795
|
+
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
|
|
7796
|
+
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
|
|
7797
|
+
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
|
|
7798
|
+
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
|
|
7799
|
+
|
|
7800
|
+
const uint32_t n_head_kv = src0->ne[2];
|
|
7566
7801
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
7567
7802
|
|
|
7568
7803
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -7571,6 +7806,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7571
7806
|
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
|
7572
7807
|
ncols,
|
|
7573
7808
|
src1 != nullptr ? nrows_y : (uint32_t)0,
|
|
7809
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7810
|
+
ne12, ne13,
|
|
7811
|
+
nb11, nb12, nb13,
|
|
7574
7812
|
scale, max_bias,
|
|
7575
7813
|
m0, m1,
|
|
7576
7814
|
n_head_log2,
|
|
@@ -8720,11 +8958,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
8720
8958
|
}
|
|
8721
8959
|
}
|
|
8722
8960
|
|
|
8723
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
8961
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
8724
8962
|
|
|
8725
8963
|
// Returns true if node has enqueued work into the queue, false otherwise
|
|
8726
8964
|
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
|
8727
|
-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx,
|
|
8965
|
+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
|
|
8966
|
+
ggml_tensor * node = cgraph->nodes[node_idx];
|
|
8728
8967
|
if (ggml_is_empty(node) || !node->buffer) {
|
|
8729
8968
|
return false;
|
|
8730
8969
|
}
|
|
@@ -8749,6 +8988,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8749
8988
|
switch (ggml_get_unary_op(node)) {
|
|
8750
8989
|
case GGML_UNARY_OP_SILU:
|
|
8751
8990
|
case GGML_UNARY_OP_GELU:
|
|
8991
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8752
8992
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8753
8993
|
case GGML_UNARY_OP_RELU:
|
|
8754
8994
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8758,6 +8998,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8758
8998
|
return false;
|
|
8759
8999
|
}
|
|
8760
9000
|
break;
|
|
9001
|
+
case GGML_OP_GLU:
|
|
9002
|
+
switch (ggml_get_glu_op(node)) {
|
|
9003
|
+
case GGML_GLU_OP_GEGLU:
|
|
9004
|
+
case GGML_GLU_OP_REGLU:
|
|
9005
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9006
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9007
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9008
|
+
break;
|
|
9009
|
+
default:
|
|
9010
|
+
return false;
|
|
9011
|
+
}
|
|
9012
|
+
break;
|
|
8761
9013
|
case GGML_OP_REPEAT:
|
|
8762
9014
|
case GGML_OP_REPEAT_BACK:
|
|
8763
9015
|
case GGML_OP_GET_ROWS:
|
|
@@ -8774,7 +9026,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8774
9026
|
case GGML_OP_COS:
|
|
8775
9027
|
case GGML_OP_CLAMP:
|
|
8776
9028
|
case GGML_OP_PAD:
|
|
9029
|
+
case GGML_OP_ROLL:
|
|
8777
9030
|
case GGML_OP_CPY:
|
|
9031
|
+
case GGML_OP_SET_ROWS:
|
|
8778
9032
|
case GGML_OP_CONT:
|
|
8779
9033
|
case GGML_OP_DUP:
|
|
8780
9034
|
case GGML_OP_SILU_BACK:
|
|
@@ -8841,6 +9095,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8841
9095
|
case GGML_OP_CLAMP:
|
|
8842
9096
|
case GGML_OP_PAD:
|
|
8843
9097
|
case GGML_OP_CPY:
|
|
9098
|
+
case GGML_OP_SET_ROWS:
|
|
8844
9099
|
case GGML_OP_CONT:
|
|
8845
9100
|
case GGML_OP_DUP:
|
|
8846
9101
|
case GGML_OP_SILU_BACK:
|
|
@@ -8850,6 +9105,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8850
9105
|
case GGML_OP_RMS_NORM_BACK:
|
|
8851
9106
|
case GGML_OP_L2_NORM:
|
|
8852
9107
|
case GGML_OP_UNARY:
|
|
9108
|
+
case GGML_OP_GLU:
|
|
8853
9109
|
case GGML_OP_DIAG_MASK_INF:
|
|
8854
9110
|
case GGML_OP_SOFT_MAX:
|
|
8855
9111
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -8942,12 +9198,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8942
9198
|
case GGML_OP_PAD:
|
|
8943
9199
|
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
|
|
8944
9200
|
|
|
9201
|
+
break;
|
|
9202
|
+
case GGML_OP_ROLL:
|
|
9203
|
+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
|
|
9204
|
+
|
|
8945
9205
|
break;
|
|
8946
9206
|
case GGML_OP_CPY:
|
|
8947
9207
|
case GGML_OP_CONT:
|
|
8948
9208
|
case GGML_OP_DUP:
|
|
8949
9209
|
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
|
8950
9210
|
|
|
9211
|
+
break;
|
|
9212
|
+
case GGML_OP_SET_ROWS:
|
|
9213
|
+
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9214
|
+
|
|
8951
9215
|
break;
|
|
8952
9216
|
case GGML_OP_SILU_BACK:
|
|
8953
9217
|
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8962,8 +9226,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8962
9226
|
|
|
8963
9227
|
break;
|
|
8964
9228
|
case GGML_OP_RMS_NORM:
|
|
8965
|
-
|
|
8966
|
-
|
|
9229
|
+
if (ctx->num_additional_fused_ops > 0) {
|
|
9230
|
+
// fused rms_norm + mul
|
|
9231
|
+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
9232
|
+
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
|
9233
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
|
|
9234
|
+
} else {
|
|
9235
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
|
|
9236
|
+
}
|
|
8967
9237
|
break;
|
|
8968
9238
|
case GGML_OP_RMS_NORM_BACK:
|
|
8969
9239
|
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8977,6 +9247,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8977
9247
|
switch (ggml_get_unary_op(node)) {
|
|
8978
9248
|
case GGML_UNARY_OP_SILU:
|
|
8979
9249
|
case GGML_UNARY_OP_GELU:
|
|
9250
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8980
9251
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8981
9252
|
case GGML_UNARY_OP_RELU:
|
|
8982
9253
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8987,6 +9258,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8987
9258
|
return false;
|
|
8988
9259
|
}
|
|
8989
9260
|
break;
|
|
9261
|
+
case GGML_OP_GLU:
|
|
9262
|
+
switch (ggml_get_glu_op(node)) {
|
|
9263
|
+
case GGML_GLU_OP_GEGLU:
|
|
9264
|
+
case GGML_GLU_OP_REGLU:
|
|
9265
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9266
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9267
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9268
|
+
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9269
|
+
break;
|
|
9270
|
+
default:
|
|
9271
|
+
return false;
|
|
9272
|
+
}
|
|
9273
|
+
break;
|
|
8990
9274
|
case GGML_OP_DIAG_MASK_INF:
|
|
8991
9275
|
ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
|
|
8992
9276
|
|
|
@@ -9108,12 +9392,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9108
9392
|
|
|
9109
9393
|
ctx->compute_ctx.reset();
|
|
9110
9394
|
|
|
9111
|
-
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
|
|
9395
|
+
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
|
|
9112
9396
|
if (!ok) {
|
|
9113
9397
|
if (node->op == GGML_OP_UNARY) {
|
|
9114
9398
|
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
9115
|
-
}
|
|
9116
|
-
|
|
9399
|
+
} else if (node->op == GGML_OP_GLU) {
|
|
9400
|
+
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
|
9401
|
+
} else {
|
|
9117
9402
|
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
|
9118
9403
|
}
|
|
9119
9404
|
}
|
|
@@ -9122,7 +9407,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9122
9407
|
return true;
|
|
9123
9408
|
}
|
|
9124
9409
|
|
|
9125
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9410
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9411
|
+
GGML_UNUSED(cgraph);
|
|
9126
9412
|
ggml_backend_buffer * buf = nullptr;
|
|
9127
9413
|
|
|
9128
9414
|
switch (tensor->op) {
|
|
@@ -9140,7 +9426,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9140
9426
|
case GGML_OP_COS:
|
|
9141
9427
|
case GGML_OP_CLAMP:
|
|
9142
9428
|
case GGML_OP_PAD:
|
|
9429
|
+
case GGML_OP_ROLL:
|
|
9143
9430
|
case GGML_OP_CPY:
|
|
9431
|
+
case GGML_OP_SET_ROWS:
|
|
9144
9432
|
case GGML_OP_CONT:
|
|
9145
9433
|
case GGML_OP_DUP:
|
|
9146
9434
|
case GGML_OP_SILU_BACK:
|
|
@@ -9182,6 +9470,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9182
9470
|
switch (ggml_get_unary_op(tensor)) {
|
|
9183
9471
|
case GGML_UNARY_OP_SILU:
|
|
9184
9472
|
case GGML_UNARY_OP_GELU:
|
|
9473
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9185
9474
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9186
9475
|
case GGML_UNARY_OP_RELU:
|
|
9187
9476
|
case GGML_UNARY_OP_TANH:
|
|
@@ -9192,6 +9481,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9192
9481
|
return false;
|
|
9193
9482
|
}
|
|
9194
9483
|
break;
|
|
9484
|
+
case GGML_OP_GLU:
|
|
9485
|
+
switch (ggml_get_glu_op(tensor)) {
|
|
9486
|
+
case GGML_GLU_OP_GEGLU:
|
|
9487
|
+
case GGML_GLU_OP_REGLU:
|
|
9488
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9489
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9490
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9491
|
+
buf = tensor->buffer;
|
|
9492
|
+
break;
|
|
9493
|
+
default:
|
|
9494
|
+
return false;
|
|
9495
|
+
}
|
|
9496
|
+
break;
|
|
9195
9497
|
case GGML_OP_MUL_MAT:
|
|
9196
9498
|
case GGML_OP_MUL_MAT_ID:
|
|
9197
9499
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
@@ -9218,7 +9520,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9218
9520
|
// Only run if ctx hasn't been submitted yet
|
|
9219
9521
|
if (!subctx->seqs.empty()) {
|
|
9220
9522
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9221
|
-
ggml_vk_check_results_0(
|
|
9523
|
+
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
|
|
9222
9524
|
use_fence = true;
|
|
9223
9525
|
#endif
|
|
9224
9526
|
|
|
@@ -9238,7 +9540,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9238
9540
|
ggml_vk_wait_for_fence(ctx);
|
|
9239
9541
|
}
|
|
9240
9542
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9241
|
-
ggml_vk_check_results_1(
|
|
9543
|
+
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
|
|
9242
9544
|
#endif
|
|
9243
9545
|
}
|
|
9244
9546
|
|
|
@@ -9685,6 +9987,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
|
|
|
9685
9987
|
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
|
9686
9988
|
}
|
|
9687
9989
|
|
|
9990
|
+
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
|
9991
|
+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
|
9992
|
+
return false;
|
|
9993
|
+
}
|
|
9994
|
+
|
|
9995
|
+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
|
9996
|
+
// additional constraints specific to this fusion
|
|
9997
|
+
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
|
9998
|
+
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
9999
|
+
|
|
10000
|
+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
|
10001
|
+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
|
10002
|
+
// rms_norm only supports f32
|
|
10003
|
+
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
|
10004
|
+
mul->src[1]->type != GGML_TYPE_F32 ||
|
|
10005
|
+
mul->type != GGML_TYPE_F32) {
|
|
10006
|
+
return false;
|
|
10007
|
+
}
|
|
10008
|
+
// if rms_norm is the B operand, then we don't handle broadcast
|
|
10009
|
+
if (rms_norm == mul->src[1] &&
|
|
10010
|
+
mul->src[0]->ne[1] != rms_norm->ne[1]) {
|
|
10011
|
+
return false;
|
|
10012
|
+
}
|
|
10013
|
+
// rms_norm shader assumes contiguous rows
|
|
10014
|
+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
|
10015
|
+
return false;
|
|
10016
|
+
}
|
|
10017
|
+
}
|
|
10018
|
+
return true;
|
|
10019
|
+
}
|
|
10020
|
+
|
|
9688
10021
|
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
9689
10022
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
|
9690
10023
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
@@ -9698,10 +10031,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9698
10031
|
|
|
9699
10032
|
uint64_t total_mat_mul_bytes = 0;
|
|
9700
10033
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
9701
|
-
|
|
10034
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10035
|
+
ctx->num_additional_fused_ops = 1;
|
|
10036
|
+
}
|
|
10037
|
+
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
|
9702
10038
|
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
|
9703
10039
|
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
9704
10040
|
}
|
|
10041
|
+
i += ctx->num_additional_fused_ops;
|
|
10042
|
+
ctx->num_additional_fused_ops = 0;
|
|
9705
10043
|
}
|
|
9706
10044
|
if (ctx->device->need_compiles) {
|
|
9707
10045
|
ggml_vk_load_shaders(ctx->device);
|
|
@@ -9763,14 +10101,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9763
10101
|
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
9764
10102
|
}
|
|
9765
10103
|
|
|
10104
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10105
|
+
ctx->num_additional_fused_ops = 1;
|
|
10106
|
+
}
|
|
10107
|
+
|
|
9766
10108
|
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
|
9767
10109
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
9768
10110
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
9769
10111
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
9770
|
-
(i == last_node) ||
|
|
10112
|
+
(i + ctx->num_additional_fused_ops == last_node) ||
|
|
9771
10113
|
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
9772
10114
|
|
|
9773
|
-
bool enqueued = ggml_vk_build_graph(ctx, cgraph
|
|
10115
|
+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
|
|
9774
10116
|
|
|
9775
10117
|
if (vk_perf_logger_enabled) {
|
|
9776
10118
|
if (ctx->compute_ctx.expired()) {
|
|
@@ -9780,7 +10122,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9780
10122
|
} else {
|
|
9781
10123
|
compute_ctx = ctx->compute_ctx.lock();
|
|
9782
10124
|
}
|
|
9783
|
-
|
|
10125
|
+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
|
|
10126
|
+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
|
|
10127
|
+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
|
|
10128
|
+
}
|
|
9784
10129
|
}
|
|
9785
10130
|
|
|
9786
10131
|
if (enqueued) {
|
|
@@ -9802,6 +10147,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9802
10147
|
}
|
|
9803
10148
|
submit_count++;
|
|
9804
10149
|
}
|
|
10150
|
+
i += ctx->num_additional_fused_ops;
|
|
10151
|
+
ctx->num_additional_fused_ops = 0;
|
|
9805
10152
|
}
|
|
9806
10153
|
|
|
9807
10154
|
if (vk_perf_logger_enabled) {
|
|
@@ -9963,6 +10310,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9963
10310
|
case GGML_OP_UNARY:
|
|
9964
10311
|
switch (ggml_get_unary_op(op)) {
|
|
9965
10312
|
case GGML_UNARY_OP_GELU:
|
|
10313
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9966
10314
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9967
10315
|
case GGML_UNARY_OP_SILU:
|
|
9968
10316
|
case GGML_UNARY_OP_RELU:
|
|
@@ -9976,15 +10324,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9976
10324
|
return false;
|
|
9977
10325
|
}
|
|
9978
10326
|
break;
|
|
10327
|
+
case GGML_OP_GLU:
|
|
10328
|
+
switch (ggml_get_glu_op(op)) {
|
|
10329
|
+
case GGML_GLU_OP_GEGLU:
|
|
10330
|
+
case GGML_GLU_OP_REGLU:
|
|
10331
|
+
case GGML_GLU_OP_SWIGLU:
|
|
10332
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
10333
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
10334
|
+
return ggml_is_contiguous(op->src[0]) &&
|
|
10335
|
+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
10336
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
10337
|
+
(op->src[0]->type == op->type);
|
|
10338
|
+
default:
|
|
10339
|
+
return false;
|
|
10340
|
+
}
|
|
10341
|
+
break;
|
|
9979
10342
|
case GGML_OP_MUL_MAT:
|
|
9980
10343
|
case GGML_OP_MUL_MAT_ID:
|
|
9981
10344
|
{
|
|
9982
10345
|
ggml_type src0_type = op->src[0]->type;
|
|
9983
10346
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
9984
10347
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
9985
|
-
if (op->op == GGML_OP_MUL_MAT_ID
|
|
9986
|
-
|
|
9987
|
-
|
|
10348
|
+
if (op->op == GGML_OP_MUL_MAT_ID) {
|
|
10349
|
+
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
|
10350
|
+
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
|
10351
|
+
return false;
|
|
10352
|
+
}
|
|
10353
|
+
// Check against size of shared memory variable
|
|
10354
|
+
if (op->src[2]->ne[0] > 4096) {
|
|
10355
|
+
return false;
|
|
10356
|
+
}
|
|
9988
10357
|
}
|
|
9989
10358
|
switch (src0_type) {
|
|
9990
10359
|
case GGML_TYPE_F32:
|
|
@@ -10042,19 +10411,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10042
10411
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
10043
10412
|
auto device = ggml_vk_get_device(ctx->device);
|
|
10044
10413
|
bool coopmat2 = device->coopmat2;
|
|
10045
|
-
|
|
10046
|
-
|
|
10047
|
-
case 80:
|
|
10048
|
-
case 96:
|
|
10049
|
-
case 112:
|
|
10050
|
-
case 128:
|
|
10051
|
-
case 256:
|
|
10052
|
-
break;
|
|
10053
|
-
default:
|
|
10054
|
-
return false;
|
|
10055
|
-
}
|
|
10056
|
-
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
|
10057
|
-
// different head sizes of K and V are not supported yet
|
|
10414
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
|
|
10415
|
+
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
|
|
10058
10416
|
return false;
|
|
10059
10417
|
}
|
|
10060
10418
|
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
@@ -10134,6 +10492,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10134
10492
|
return false;
|
|
10135
10493
|
}
|
|
10136
10494
|
} break;
|
|
10495
|
+
case GGML_OP_SET_ROWS:
|
|
10496
|
+
{
|
|
10497
|
+
switch (op->type) {
|
|
10498
|
+
case GGML_TYPE_F32:
|
|
10499
|
+
case GGML_TYPE_F16:
|
|
10500
|
+
case GGML_TYPE_BF16:
|
|
10501
|
+
case GGML_TYPE_Q4_0:
|
|
10502
|
+
case GGML_TYPE_Q4_1:
|
|
10503
|
+
case GGML_TYPE_Q5_0:
|
|
10504
|
+
case GGML_TYPE_Q5_1:
|
|
10505
|
+
case GGML_TYPE_Q8_0:
|
|
10506
|
+
case GGML_TYPE_IQ4_NL:
|
|
10507
|
+
return true;
|
|
10508
|
+
default:
|
|
10509
|
+
return false;
|
|
10510
|
+
}
|
|
10511
|
+
} break;
|
|
10137
10512
|
case GGML_OP_CONT:
|
|
10138
10513
|
case GGML_OP_CPY:
|
|
10139
10514
|
case GGML_OP_DUP:
|
|
@@ -10218,11 +10593,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10218
10593
|
case GGML_OP_CLAMP:
|
|
10219
10594
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
10220
10595
|
case GGML_OP_UPSCALE:
|
|
10221
|
-
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
10222
10596
|
case GGML_OP_ACC:
|
|
10223
10597
|
case GGML_OP_CONCAT:
|
|
10224
10598
|
case GGML_OP_SCALE:
|
|
10225
10599
|
case GGML_OP_PAD:
|
|
10600
|
+
case GGML_OP_ROLL:
|
|
10226
10601
|
case GGML_OP_DIAG_MASK_INF:
|
|
10227
10602
|
case GGML_OP_SOFT_MAX:
|
|
10228
10603
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -10513,11 +10888,21 @@ void * comp_result;
|
|
|
10513
10888
|
size_t comp_size;
|
|
10514
10889
|
size_t comp_nb[GGML_MAX_DIMS];
|
|
10515
10890
|
size_t check_counter = 0;
|
|
10516
|
-
static void ggml_vk_check_results_0(
|
|
10891
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
10892
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10517
10893
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10518
10894
|
return;
|
|
10519
10895
|
}
|
|
10520
10896
|
|
|
10897
|
+
bool fused_rms_norm_mul = false;
|
|
10898
|
+
int rms_norm_idx = -1;
|
|
10899
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
10900
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
10901
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
10902
|
+
fused_rms_norm_mul = true;
|
|
10903
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
10904
|
+
}
|
|
10905
|
+
|
|
10521
10906
|
check_counter++;
|
|
10522
10907
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10523
10908
|
return;
|
|
@@ -10545,6 +10930,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10545
10930
|
|
|
10546
10931
|
for (int i = 0; i < 6; i++) {
|
|
10547
10932
|
ggml_tensor * srci = tensor->src[i];
|
|
10933
|
+
if (fused_rms_norm_mul) {
|
|
10934
|
+
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
|
10935
|
+
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
|
|
10936
|
+
switch (i) {
|
|
10937
|
+
case 0: srci = rms_norm->src[0]; break;
|
|
10938
|
+
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
|
|
10939
|
+
default: continue;
|
|
10940
|
+
}
|
|
10941
|
+
}
|
|
10548
10942
|
if (srci == nullptr) {
|
|
10549
10943
|
continue;
|
|
10550
10944
|
}
|
|
@@ -10602,7 +10996,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10602
10996
|
} else if (tensor->op == GGML_OP_SUB) {
|
|
10603
10997
|
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10604
10998
|
} else if (tensor->op == GGML_OP_MUL) {
|
|
10605
|
-
|
|
10999
|
+
if (fused_rms_norm_mul) {
|
|
11000
|
+
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
|
|
11001
|
+
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
|
|
11002
|
+
} else {
|
|
11003
|
+
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
|
11004
|
+
}
|
|
10606
11005
|
} else if (tensor->op == GGML_OP_DIV) {
|
|
10607
11006
|
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10608
11007
|
} else if (tensor->op == GGML_OP_CONCAT) {
|
|
@@ -10690,6 +11089,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10690
11089
|
case GGML_UNARY_OP_GELU:
|
|
10691
11090
|
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
|
|
10692
11091
|
break;
|
|
11092
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
11093
|
+
tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
|
|
11094
|
+
break;
|
|
10693
11095
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
10694
11096
|
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
|
|
10695
11097
|
break;
|
|
@@ -10706,6 +11108,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10706
11108
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
10707
11109
|
GGML_ABORT("fatal error");
|
|
10708
11110
|
}
|
|
11111
|
+
} else if (tensor->op == GGML_OP_GLU) {
|
|
11112
|
+
if (src_clone[1] == nullptr) {
|
|
11113
|
+
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
|
|
11114
|
+
} else {
|
|
11115
|
+
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
|
11116
|
+
}
|
|
10709
11117
|
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
|
10710
11118
|
if (src1 == nullptr) {
|
|
10711
11119
|
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
|
@@ -10713,6 +11121,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10713
11121
|
} else {
|
|
10714
11122
|
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10715
11123
|
}
|
|
11124
|
+
} else if (tensor->op == GGML_OP_SET_ROWS) {
|
|
11125
|
+
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10716
11126
|
} else if (tensor->op == GGML_OP_CONT) {
|
|
10717
11127
|
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
10718
11128
|
} else if (tensor->op == GGML_OP_RESHAPE) {
|
|
@@ -10784,10 +11194,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10784
11194
|
GGML_ABORT("fatal error");
|
|
10785
11195
|
}
|
|
10786
11196
|
|
|
10787
|
-
ggml_cgraph *
|
|
10788
|
-
ggml_build_forward_expand(
|
|
11197
|
+
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
|
|
11198
|
+
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
|
|
10789
11199
|
|
|
10790
|
-
ggml_graph_compute_with_ctx(ggml_ctx,
|
|
11200
|
+
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
|
|
10791
11201
|
|
|
10792
11202
|
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
|
10793
11203
|
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
|
|
@@ -10810,10 +11220,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10810
11220
|
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
|
|
10811
11221
|
}
|
|
10812
11222
|
|
|
10813
|
-
static void ggml_vk_check_results_1(
|
|
11223
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
11224
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10814
11225
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10815
11226
|
return;
|
|
10816
11227
|
}
|
|
11228
|
+
bool fused_rms_norm_mul = false;
|
|
11229
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
11230
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
11231
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
11232
|
+
fused_rms_norm_mul = true;
|
|
11233
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
11234
|
+
}
|
|
11235
|
+
|
|
10817
11236
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10818
11237
|
return;
|
|
10819
11238
|
}
|