@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -224,6 +224,21 @@ enum vk_device_architecture {
|
|
|
224
224
|
INTEL_XE2,
|
|
225
225
|
};
|
|
226
226
|
|
|
227
|
+
// HSK x HSV
|
|
228
|
+
enum FaHeadSizes {
|
|
229
|
+
FA_HEAD_SIZE_64,
|
|
230
|
+
FA_HEAD_SIZE_80,
|
|
231
|
+
FA_HEAD_SIZE_96,
|
|
232
|
+
FA_HEAD_SIZE_112,
|
|
233
|
+
FA_HEAD_SIZE_128,
|
|
234
|
+
FA_HEAD_SIZE_192,
|
|
235
|
+
FA_HEAD_SIZE_192_128,
|
|
236
|
+
FA_HEAD_SIZE_256,
|
|
237
|
+
FA_HEAD_SIZE_576_512,
|
|
238
|
+
FA_HEAD_SIZE_UNSUPPORTED,
|
|
239
|
+
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
|
|
240
|
+
};
|
|
241
|
+
|
|
227
242
|
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
|
|
228
243
|
vk::PhysicalDeviceProperties props = device.getProperties();
|
|
229
244
|
|
|
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
|
|
305
320
|
}
|
|
306
321
|
|
|
307
322
|
struct vk_device_struct {
|
|
308
|
-
std::mutex mutex;
|
|
323
|
+
std::recursive_mutex mutex;
|
|
309
324
|
|
|
310
325
|
vk::PhysicalDevice physical_device;
|
|
311
326
|
vk::PhysicalDeviceProperties properties;
|
|
@@ -313,6 +328,7 @@ struct vk_device_struct {
|
|
|
313
328
|
uint64_t max_memory_allocation_size;
|
|
314
329
|
uint64_t suballocation_block_size;
|
|
315
330
|
bool fp16;
|
|
331
|
+
bool bf16;
|
|
316
332
|
bool pipeline_robustness;
|
|
317
333
|
vk::Device device;
|
|
318
334
|
uint32_t vendor_id;
|
|
@@ -410,32 +426,42 @@ struct vk_device_struct {
|
|
|
410
426
|
vk_pipeline pipeline_div_norepeat[2][2][2];
|
|
411
427
|
|
|
412
428
|
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
|
413
|
-
vk_pipeline pipeline_upscale_f32;
|
|
429
|
+
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
|
|
414
430
|
vk_pipeline pipeline_scale_f32;
|
|
415
431
|
vk_pipeline pipeline_sqr_f32;
|
|
416
432
|
vk_pipeline pipeline_sin_f32;
|
|
417
433
|
vk_pipeline pipeline_cos_f32;
|
|
418
434
|
vk_pipeline pipeline_clamp_f32;
|
|
419
435
|
vk_pipeline pipeline_pad_f32;
|
|
436
|
+
vk_pipeline pipeline_roll_f32;
|
|
420
437
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
|
421
438
|
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
|
|
422
439
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
|
|
423
440
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
|
424
441
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
|
442
|
+
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
|
|
425
443
|
vk_pipeline pipeline_norm_f32;
|
|
426
444
|
vk_pipeline pipeline_group_norm_f32;
|
|
427
445
|
vk_pipeline pipeline_rms_norm_f32;
|
|
446
|
+
vk_pipeline pipeline_rms_norm_mul_f32;
|
|
428
447
|
vk_pipeline pipeline_rms_norm_back_f32;
|
|
429
448
|
vk_pipeline pipeline_l2_norm_f32;
|
|
430
449
|
|
|
431
450
|
// [src/dst 0=fp32,1=fp16]
|
|
432
451
|
vk_pipeline pipeline_gelu[2];
|
|
452
|
+
vk_pipeline pipeline_gelu_erf[2];
|
|
433
453
|
vk_pipeline pipeline_gelu_quick[2];
|
|
434
454
|
vk_pipeline pipeline_silu[2];
|
|
435
455
|
vk_pipeline pipeline_relu[2];
|
|
436
456
|
vk_pipeline pipeline_tanh[2];
|
|
437
457
|
vk_pipeline pipeline_sigmoid[2];
|
|
438
458
|
|
|
459
|
+
vk_pipeline pipeline_geglu[2];
|
|
460
|
+
vk_pipeline pipeline_reglu[2];
|
|
461
|
+
vk_pipeline pipeline_swiglu[2];
|
|
462
|
+
vk_pipeline pipeline_geglu_erf[2];
|
|
463
|
+
vk_pipeline pipeline_geglu_quick[2];
|
|
464
|
+
|
|
439
465
|
vk_pipeline pipeline_leaky_relu_f32;
|
|
440
466
|
vk_pipeline pipeline_silu_back_f32;
|
|
441
467
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
@@ -457,30 +483,16 @@ struct vk_device_struct {
|
|
|
457
483
|
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
458
484
|
vk_pipeline pipeline_rwkv_wkv7_f32;
|
|
459
485
|
vk_pipeline pipeline_opt_step_adamw_f32;
|
|
486
|
+
vk_pipeline pipeline_conv2d_f32;
|
|
460
487
|
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
|
461
488
|
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
462
489
|
|
|
463
490
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
464
|
-
vk_pipeline
|
|
465
|
-
|
|
466
|
-
vk_pipeline
|
|
467
|
-
|
|
468
|
-
vk_pipeline
|
|
469
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
470
|
-
|
|
471
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
472
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
473
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
474
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
475
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
476
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
477
|
-
|
|
478
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
479
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
480
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
481
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
|
482
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
483
|
-
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
491
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
492
|
+
|
|
493
|
+
vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
494
|
+
|
|
495
|
+
vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
|
|
484
496
|
|
|
485
497
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
486
498
|
|
|
@@ -493,6 +505,8 @@ struct vk_device_struct {
|
|
|
493
505
|
|
|
494
506
|
ggml_backend_buffer_type buffer_type;
|
|
495
507
|
|
|
508
|
+
bool disable_fusion;
|
|
509
|
+
|
|
496
510
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
497
511
|
std::unique_ptr<vk_memory_logger> memory_logger;
|
|
498
512
|
#endif
|
|
@@ -627,6 +641,8 @@ struct vk_flash_attn_push_constants {
|
|
|
627
641
|
uint32_t nev2;
|
|
628
642
|
uint32_t nev3;
|
|
629
643
|
uint32_t nem1;
|
|
644
|
+
uint32_t nem2;
|
|
645
|
+
uint32_t nem3;
|
|
630
646
|
|
|
631
647
|
uint32_t nb01;
|
|
632
648
|
uint32_t nb02;
|
|
@@ -637,14 +653,12 @@ struct vk_flash_attn_push_constants {
|
|
|
637
653
|
uint32_t nb21;
|
|
638
654
|
uint32_t nb22;
|
|
639
655
|
uint32_t nb23;
|
|
640
|
-
uint32_t nb31;
|
|
641
656
|
|
|
642
657
|
float scale;
|
|
643
658
|
float max_bias;
|
|
644
659
|
float logit_softcap;
|
|
645
660
|
|
|
646
|
-
uint32_t mask;
|
|
647
|
-
uint32_t n_head_log2;
|
|
661
|
+
uint32_t mask_n_head_log2;
|
|
648
662
|
float m0;
|
|
649
663
|
float m1;
|
|
650
664
|
|
|
@@ -652,6 +666,7 @@ struct vk_flash_attn_push_constants {
|
|
|
652
666
|
uint32_t split_kv;
|
|
653
667
|
uint32_t k_num;
|
|
654
668
|
};
|
|
669
|
+
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
|
|
655
670
|
|
|
656
671
|
struct vk_op_push_constants {
|
|
657
672
|
uint32_t KX;
|
|
@@ -660,6 +675,13 @@ struct vk_op_push_constants {
|
|
|
660
675
|
float param2;
|
|
661
676
|
};
|
|
662
677
|
|
|
678
|
+
struct vk_op_glu_push_constants {
|
|
679
|
+
uint32_t N;
|
|
680
|
+
uint32_t ne00;
|
|
681
|
+
uint32_t ne20;
|
|
682
|
+
uint32_t mode; // 0: default, 1: swapped, 2: split
|
|
683
|
+
};
|
|
684
|
+
|
|
663
685
|
struct vk_op_unary_push_constants {
|
|
664
686
|
uint32_t ne;
|
|
665
687
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
@@ -675,6 +697,37 @@ struct vk_op_unary_push_constants {
|
|
|
675
697
|
};
|
|
676
698
|
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
|
677
699
|
|
|
700
|
+
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
|
|
701
|
+
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
|
|
702
|
+
ne = ne != 0 ? ne : ggml_nelements(dst);
|
|
703
|
+
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
|
|
704
|
+
|
|
705
|
+
vk_op_unary_push_constants p{};
|
|
706
|
+
p.ne = (uint32_t)ne;
|
|
707
|
+
|
|
708
|
+
size_t src0_tsize = ggml_type_size(src0->type);
|
|
709
|
+
p.ne00 = (uint32_t)src0->ne[0];
|
|
710
|
+
p.ne01 = (uint32_t)src0->ne[1];
|
|
711
|
+
p.ne02 = (uint32_t)src0->ne[2];
|
|
712
|
+
p.ne03 = (uint32_t)src0->ne[3];
|
|
713
|
+
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
|
|
714
|
+
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
|
|
715
|
+
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
|
|
716
|
+
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
|
|
717
|
+
|
|
718
|
+
size_t dst_tsize = ggml_type_size(dst->type);
|
|
719
|
+
p.ne10 = (uint32_t)dst->ne[0];
|
|
720
|
+
p.ne11 = (uint32_t)dst->ne[1];
|
|
721
|
+
p.ne12 = (uint32_t)dst->ne[2];
|
|
722
|
+
p.ne13 = (uint32_t)dst->ne[3];
|
|
723
|
+
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
|
|
724
|
+
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
|
|
725
|
+
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
|
|
726
|
+
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
|
|
727
|
+
|
|
728
|
+
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
|
|
729
|
+
}
|
|
730
|
+
|
|
678
731
|
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
|
679
732
|
// Precompute mp (m' in the paper) and L such that division
|
|
680
733
|
// can be computed using a multiply (high 32b of 64b result)
|
|
@@ -743,6 +796,14 @@ struct vk_op_rope_push_constants {
|
|
|
743
796
|
struct vk_op_soft_max_push_constants {
|
|
744
797
|
uint32_t KX;
|
|
745
798
|
uint32_t KY;
|
|
799
|
+
uint32_t ne00;
|
|
800
|
+
uint32_t ne01;
|
|
801
|
+
uint32_t ne02;
|
|
802
|
+
uint32_t ne12;
|
|
803
|
+
uint32_t ne13;
|
|
804
|
+
uint32_t nb11;
|
|
805
|
+
uint32_t nb12;
|
|
806
|
+
uint32_t nb13;
|
|
746
807
|
float scale;
|
|
747
808
|
float max_bias;
|
|
748
809
|
float m0;
|
|
@@ -816,6 +877,38 @@ struct vk_op_rwkv_wkv7_push_constants {
|
|
|
816
877
|
uint32_t H;
|
|
817
878
|
};
|
|
818
879
|
|
|
880
|
+
struct vk_op_conv2d_push_constants {
|
|
881
|
+
uint32_t Cout;
|
|
882
|
+
uint32_t Cin;
|
|
883
|
+
uint32_t N;
|
|
884
|
+
|
|
885
|
+
uint32_t KW;
|
|
886
|
+
uint32_t KH;
|
|
887
|
+
uint32_t W;
|
|
888
|
+
uint32_t H;
|
|
889
|
+
uint32_t OW;
|
|
890
|
+
uint32_t OH;
|
|
891
|
+
|
|
892
|
+
uint32_t s0;
|
|
893
|
+
uint32_t s1;
|
|
894
|
+
uint32_t p0;
|
|
895
|
+
uint32_t p1;
|
|
896
|
+
uint32_t d0;
|
|
897
|
+
uint32_t d1;
|
|
898
|
+
|
|
899
|
+
uint32_t nb01;
|
|
900
|
+
uint32_t nb02;
|
|
901
|
+
uint32_t nb03;
|
|
902
|
+
|
|
903
|
+
uint32_t nb11;
|
|
904
|
+
uint32_t nb12;
|
|
905
|
+
uint32_t nb13;
|
|
906
|
+
|
|
907
|
+
uint32_t nb1;
|
|
908
|
+
uint32_t nb2;
|
|
909
|
+
uint32_t nb3;
|
|
910
|
+
};
|
|
911
|
+
|
|
819
912
|
struct vk_op_conv2d_dw_push_constants {
|
|
820
913
|
uint32_t ne;
|
|
821
914
|
uint32_t batches;
|
|
@@ -836,6 +929,7 @@ struct vk_op_conv2d_dw_push_constants {
|
|
|
836
929
|
|
|
837
930
|
struct vk_op_upscale_push_constants {
|
|
838
931
|
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
|
932
|
+
uint32_t ne00; uint32_t ne01;
|
|
839
933
|
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
|
840
934
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
|
841
935
|
float sf0; float sf1; float sf2; float sf3;
|
|
@@ -914,18 +1008,45 @@ private:
|
|
|
914
1008
|
#endif // GGML_VULKAN_MEMORY_DEBUG
|
|
915
1009
|
|
|
916
1010
|
class vk_perf_logger {
|
|
917
|
-
public:
|
|
1011
|
+
public:
|
|
918
1012
|
void print_timings() {
|
|
1013
|
+
if (timings.empty()) {
|
|
1014
|
+
return;
|
|
1015
|
+
}
|
|
1016
|
+
uint64_t total_all_op_times = 0;
|
|
919
1017
|
std::cerr << "----------------\nVulkan Timings:" << std::endl;
|
|
920
|
-
for (const auto& t : timings) {
|
|
921
|
-
uint64_t
|
|
922
|
-
for (const auto& time : t.second) {
|
|
923
|
-
|
|
1018
|
+
for (const auto & t : timings) {
|
|
1019
|
+
uint64_t total_op_times = 0;
|
|
1020
|
+
for (const auto & time : t.second) {
|
|
1021
|
+
total_op_times += time;
|
|
1022
|
+
}
|
|
1023
|
+
std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
|
|
1024
|
+
<< " us";
|
|
1025
|
+
|
|
1026
|
+
// If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
|
|
1027
|
+
auto it = flops.find(t.first);
|
|
1028
|
+
if (it != flops.end() && (it->second).size() == t.second.size()) {
|
|
1029
|
+
uint64_t total_op_flops = 0;
|
|
1030
|
+
for (const auto & elem : it->second) {
|
|
1031
|
+
total_op_flops += elem;
|
|
1032
|
+
}
|
|
1033
|
+
std::cerr << " ("
|
|
1034
|
+
<< (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
|
|
1035
|
+
(double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
|
|
1036
|
+
<< " GFLOPS/s)";
|
|
924
1037
|
}
|
|
925
|
-
|
|
1038
|
+
|
|
1039
|
+
total_all_op_times += total_op_times;
|
|
1040
|
+
|
|
1041
|
+
std::cerr << std::endl;
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
if (timings.size() > 0) {
|
|
1045
|
+
std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
|
|
926
1046
|
}
|
|
927
1047
|
|
|
928
1048
|
timings.clear();
|
|
1049
|
+
flops.clear();
|
|
929
1050
|
}
|
|
930
1051
|
|
|
931
1052
|
void log_timing(const ggml_tensor * node, uint64_t time) {
|
|
@@ -934,22 +1055,45 @@ public:
|
|
|
934
1055
|
return;
|
|
935
1056
|
}
|
|
936
1057
|
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
|
|
937
|
-
const uint64_t m
|
|
938
|
-
const uint64_t n
|
|
939
|
-
const uint64_t k
|
|
940
|
-
std::string
|
|
1058
|
+
const uint64_t m = node->src[0]->ne[1];
|
|
1059
|
+
const uint64_t n = node->src[1]->ne[1];
|
|
1060
|
+
const uint64_t k = node->src[1]->ne[0];
|
|
1061
|
+
std::string name = ggml_op_name(node->op);
|
|
941
1062
|
if (n == 1) {
|
|
942
1063
|
name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
|
|
943
1064
|
} else {
|
|
944
1065
|
name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
|
|
945
1066
|
}
|
|
946
1067
|
timings[name].push_back(time);
|
|
1068
|
+
flops[name].push_back(m * n * (k + (k - 1)));
|
|
1069
|
+
return;
|
|
1070
|
+
}
|
|
1071
|
+
if (node->op == GGML_OP_CONV_2D) {
|
|
1072
|
+
std::string name = ggml_op_name(node->op);
|
|
1073
|
+
ggml_tensor * knl = node->src[0];
|
|
1074
|
+
uint64_t OW = node->ne[0];
|
|
1075
|
+
uint64_t OH = node->ne[1];
|
|
1076
|
+
uint64_t N = node->ne[3];
|
|
1077
|
+
uint64_t Cout = node->ne[2];
|
|
1078
|
+
uint64_t KW = knl->ne[0];
|
|
1079
|
+
uint64_t KH = knl->ne[1];
|
|
1080
|
+
uint64_t Cin = knl->ne[2];
|
|
1081
|
+
// KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
|
|
1082
|
+
uint64_t size_M = Cout;
|
|
1083
|
+
uint64_t size_K = Cin * KW * KH;
|
|
1084
|
+
uint64_t size_N = N * OW * OH;
|
|
1085
|
+
uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
|
|
1086
|
+
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
|
|
1087
|
+
", N=N*OW*OH=" + std::to_string(size_N);
|
|
1088
|
+
flops[name].push_back(n_flops);
|
|
1089
|
+
timings[name].push_back(time);
|
|
947
1090
|
return;
|
|
948
1091
|
}
|
|
949
1092
|
timings[ggml_op_name(node->op)].push_back(time);
|
|
950
1093
|
}
|
|
951
|
-
private:
|
|
1094
|
+
private:
|
|
952
1095
|
std::map<std::string, std::vector<uint64_t>> timings;
|
|
1096
|
+
std::map<std::string, std::vector<uint64_t>> flops;
|
|
953
1097
|
};
|
|
954
1098
|
|
|
955
1099
|
struct ggml_backend_vk_context {
|
|
@@ -978,6 +1122,10 @@ struct ggml_backend_vk_context {
|
|
|
978
1122
|
|
|
979
1123
|
vk_command_pool compute_cmd_pool;
|
|
980
1124
|
vk_command_pool transfer_cmd_pool;
|
|
1125
|
+
|
|
1126
|
+
// number of additional consecutive nodes that are being fused with the
|
|
1127
|
+
// node currently being processed
|
|
1128
|
+
int num_additional_fused_ops {};
|
|
981
1129
|
};
|
|
982
1130
|
|
|
983
1131
|
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
|
@@ -1063,8 +1211,8 @@ static size_t vk_skip_checks;
|
|
|
1063
1211
|
static size_t vk_output_tensor;
|
|
1064
1212
|
|
|
1065
1213
|
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
|
|
1066
|
-
static void ggml_vk_check_results_0(
|
|
1067
|
-
static void ggml_vk_check_results_1(
|
|
1214
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1215
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
|
|
1068
1216
|
#endif
|
|
1069
1217
|
|
|
1070
1218
|
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
|
@@ -1197,7 +1345,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
1197
1345
|
}
|
|
1198
1346
|
|
|
1199
1347
|
{
|
|
1200
|
-
std::lock_guard<std::
|
|
1348
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1201
1349
|
device->pipelines.insert({ pipeline->name, pipeline });
|
|
1202
1350
|
}
|
|
1203
1351
|
|
|
@@ -1411,7 +1559,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
|
|
|
1411
1559
|
|
|
1412
1560
|
static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
|
|
1413
1561
|
VK_LOG_DEBUG("ggml_vk_create_queue()");
|
|
1414
|
-
std::lock_guard<std::
|
|
1562
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
1415
1563
|
|
|
1416
1564
|
q.queue_family_index = queue_family_index;
|
|
1417
1565
|
q.transfer_only = transfer_only;
|
|
@@ -1673,10 +1821,46 @@ enum FaCodePath {
|
|
|
1673
1821
|
FA_COOPMAT2,
|
|
1674
1822
|
};
|
|
1675
1823
|
|
|
1824
|
+
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
|
|
1825
|
+
if (hsk != 192 && hsk != 576 && hsk != hsv) {
|
|
1826
|
+
return FA_HEAD_SIZE_UNSUPPORTED;
|
|
1827
|
+
}
|
|
1828
|
+
switch (hsk) {
|
|
1829
|
+
case 64: return FA_HEAD_SIZE_64;
|
|
1830
|
+
case 80: return FA_HEAD_SIZE_80;
|
|
1831
|
+
case 96: return FA_HEAD_SIZE_96;
|
|
1832
|
+
case 112: return FA_HEAD_SIZE_112;
|
|
1833
|
+
case 128: return FA_HEAD_SIZE_128;
|
|
1834
|
+
case 192:
|
|
1835
|
+
if (hsv == 192) {
|
|
1836
|
+
return FA_HEAD_SIZE_192;
|
|
1837
|
+
} else if (hsv == 128) {
|
|
1838
|
+
return FA_HEAD_SIZE_192_128;
|
|
1839
|
+
} else {
|
|
1840
|
+
return FA_HEAD_SIZE_UNSUPPORTED;
|
|
1841
|
+
}
|
|
1842
|
+
case 256: return FA_HEAD_SIZE_256;
|
|
1843
|
+
case 576:
|
|
1844
|
+
if (hsv == 512) {
|
|
1845
|
+
return FA_HEAD_SIZE_576_512;
|
|
1846
|
+
} else {
|
|
1847
|
+
return FA_HEAD_SIZE_UNSUPPORTED;
|
|
1848
|
+
}
|
|
1849
|
+
default: return FA_HEAD_SIZE_UNSUPPORTED;
|
|
1850
|
+
}
|
|
1851
|
+
}
|
|
1852
|
+
|
|
1676
1853
|
// number of rows/cols for flash attention shader
|
|
1677
1854
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1678
1855
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
1679
|
-
|
|
1856
|
+
|
|
1857
|
+
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
|
|
1858
|
+
if (hsv >= 512) {
|
|
1859
|
+
return 2;
|
|
1860
|
+
} else {
|
|
1861
|
+
return 8;
|
|
1862
|
+
}
|
|
1863
|
+
}
|
|
1680
1864
|
|
|
1681
1865
|
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
1682
1866
|
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
@@ -1693,14 +1877,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|
|
1693
1877
|
}
|
|
1694
1878
|
}
|
|
1695
1879
|
|
|
1696
|
-
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t
|
|
1880
|
+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1697
1881
|
GGML_UNUSED(clamp);
|
|
1882
|
+
GGML_UNUSED(hsv);
|
|
1698
1883
|
|
|
1699
1884
|
if (path == FA_SCALAR) {
|
|
1700
1885
|
if (small_rows) {
|
|
1701
1886
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
1702
1887
|
} else {
|
|
1703
|
-
return {
|
|
1888
|
+
return {get_fa_scalar_num_large_rows(hsv), 32};
|
|
1704
1889
|
}
|
|
1705
1890
|
}
|
|
1706
1891
|
|
|
@@ -1718,8 +1903,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
|
|
|
1718
1903
|
}
|
|
1719
1904
|
|
|
1720
1905
|
// small cols to reduce register count
|
|
1721
|
-
if (ggml_is_quantized(type) ||
|
|
1722
|
-
|
|
1906
|
+
if (ggml_is_quantized(type) || hsk >= 256) {
|
|
1907
|
+
if (hsk >= 512) {
|
|
1908
|
+
return {32, 32};
|
|
1909
|
+
} else {
|
|
1910
|
+
return {64, 32};
|
|
1911
|
+
}
|
|
1723
1912
|
}
|
|
1724
1913
|
return {64, 64};
|
|
1725
1914
|
}
|
|
@@ -1761,7 +1950,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
1761
1950
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
1762
1951
|
|
|
1763
1952
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
1764
|
-
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
|
1953
|
+
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
|
|
1765
1954
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
1766
1955
|
|
|
1767
1956
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
@@ -1886,10 +2075,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1886
2075
|
s_mmq_wg_denoms_k = { 32, 32, 1 };
|
|
1887
2076
|
|
|
1888
2077
|
// spec constants and tile sizes for quant matmul_id
|
|
1889
|
-
l_warptile_mmqid = { 256, 128,
|
|
2078
|
+
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
|
1890
2079
|
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1891
2080
|
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
|
1892
|
-
l_mmqid_wg_denoms = { 128,
|
|
2081
|
+
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
1893
2082
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1894
2083
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1895
2084
|
|
|
@@ -2007,23 +2196,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2007
2196
|
}
|
|
2008
2197
|
compile_count++;
|
|
2009
2198
|
}
|
|
2199
|
+
|
|
2010
2200
|
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
|
|
2011
2201
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
2012
2202
|
};
|
|
2013
2203
|
|
|
2014
|
-
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t
|
|
2015
|
-
return {fa_rows_cols(path,
|
|
2204
|
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
2205
|
+
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
|
2016
2206
|
};
|
|
2017
2207
|
|
|
2018
|
-
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t
|
|
2208
|
+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
2019
2209
|
// For large number of rows, 128 invocations seems to work best.
|
|
2020
2210
|
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
2021
2211
|
// can't use 256 for D==80.
|
|
2022
2212
|
// For scalar, use 128 (arbitrary)
|
|
2213
|
+
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
|
2214
|
+
const uint32_t D = (hsk|hsv);
|
|
2023
2215
|
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
2024
2216
|
? scalar_flash_attention_workgroup_size
|
|
2025
2217
|
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
2026
|
-
auto rows_cols = fa_rows_cols(path,
|
|
2218
|
+
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
|
|
2027
2219
|
|
|
2028
2220
|
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
2029
2221
|
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
@@ -2032,26 +2224,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2032
2224
|
|
|
2033
2225
|
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
2034
2226
|
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
2035
|
-
return {wg_size, rows_cols[0], rows_cols[1],
|
|
2227
|
+
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
|
2036
2228
|
};
|
|
2037
2229
|
|
|
2038
|
-
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2039
|
-
ggml_vk_create_pipeline(device, device->
|
|
2040
|
-
ggml_vk_create_pipeline(device, device->
|
|
2041
|
-
ggml_vk_create_pipeline(device, device->
|
|
2042
|
-
ggml_vk_create_pipeline(device, device->
|
|
2043
|
-
ggml_vk_create_pipeline(device, device->
|
|
2044
|
-
ggml_vk_create_pipeline(device, device->
|
|
2045
|
-
ggml_vk_create_pipeline(device, device->
|
|
2046
|
-
ggml_vk_create_pipeline(device, device->
|
|
2230
|
+
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
|
|
2231
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2232
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2233
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2234
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2235
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2236
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2237
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2238
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
2047
2239
|
|
|
2048
2240
|
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
2049
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
|
2050
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
|
2051
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
|
2052
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
|
2053
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
|
2054
|
-
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX,
|
|
2241
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
|
|
2242
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
|
|
2243
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
|
|
2244
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
|
|
2245
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
|
|
2246
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
|
|
2247
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
|
|
2248
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
|
|
2249
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
|
|
2055
2250
|
|
|
2056
2251
|
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
2057
2252
|
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
@@ -2641,7 +2836,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2641
2836
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2642
2837
|
|
|
2643
2838
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
2644
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2,
|
|
2839
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
|
2645
2840
|
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
|
2646
2841
|
|
|
2647
2842
|
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
@@ -2655,7 +2850,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2655
2850
|
|
|
2656
2851
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2657
2852
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2658
|
-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main",
|
|
2853
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
|
|
2854
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
|
|
2659
2855
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2660
2856
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2661
2857
|
|
|
@@ -2672,19 +2868,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2672
2868
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2673
2869
|
|
|
2674
2870
|
if (device->float_controls_rte_fp16) {
|
|
2675
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2676
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2677
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2678
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2679
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2680
|
-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {
|
|
2871
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2872
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2873
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2874
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2875
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2876
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2877
|
+
} else {
|
|
2878
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2879
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2880
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2881
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2882
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2883
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
|
|
2884
|
+
}
|
|
2885
|
+
|
|
2886
|
+
if (device->float_controls_rte_fp16) {
|
|
2887
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2888
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2889
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2890
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2891
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2892
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2893
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2894
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2895
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2681
2896
|
} else {
|
|
2682
|
-
ggml_vk_create_pipeline(device, device->
|
|
2683
|
-
ggml_vk_create_pipeline(device, device->
|
|
2684
|
-
ggml_vk_create_pipeline(device, device->
|
|
2685
|
-
ggml_vk_create_pipeline(device, device->
|
|
2686
|
-
ggml_vk_create_pipeline(device, device->
|
|
2687
|
-
ggml_vk_create_pipeline(device, device->
|
|
2897
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2898
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2899
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2900
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2901
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2902
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2903
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2904
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2905
|
+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
|
|
2688
2906
|
}
|
|
2689
2907
|
|
|
2690
2908
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
@@ -2702,10 +2920,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2702
2920
|
return s;
|
|
2703
2921
|
};
|
|
2704
2922
|
|
|
2923
|
+
bool rte = device->float_controls_rte_fp16;
|
|
2705
2924
|
#define CREATE_BINARY(name, namemod, spec) \
|
|
2706
2925
|
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
|
2707
2926
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
|
2708
|
-
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
|
2927
|
+
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
|
|
2709
2928
|
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
|
2710
2929
|
|
|
2711
2930
|
CREATE_BINARY(add, , {0})
|
|
@@ -2724,7 +2943,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2724
2943
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2725
2944
|
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2726
2945
|
|
|
2727
|
-
ggml_vk_create_pipeline(device, device->
|
|
2946
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
|
|
2947
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
|
|
2948
|
+
ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
|
|
2728
2949
|
|
|
2729
2950
|
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2730
2951
|
|
|
@@ -2736,6 +2957,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2736
2957
|
|
|
2737
2958
|
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2738
2959
|
|
|
2960
|
+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2961
|
+
|
|
2739
2962
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2740
2963
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2741
2964
|
|
|
@@ -2744,6 +2967,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2744
2967
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2745
2968
|
|
|
2746
2969
|
CREATE_UNARY(gelu)
|
|
2970
|
+
CREATE_UNARY(gelu_erf)
|
|
2747
2971
|
CREATE_UNARY(gelu_quick)
|
|
2748
2972
|
CREATE_UNARY(silu)
|
|
2749
2973
|
CREATE_UNARY(relu)
|
|
@@ -2751,6 +2975,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2751
2975
|
CREATE_UNARY(sigmoid)
|
|
2752
2976
|
#undef CREATE_UNARY
|
|
2753
2977
|
|
|
2978
|
+
#define CREATE_GLU(name) \
|
|
2979
|
+
if (device->float_controls_rte_fp16) { \
|
|
2980
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2981
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2982
|
+
} else { \
|
|
2983
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2984
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
|
2985
|
+
}
|
|
2986
|
+
|
|
2987
|
+
CREATE_GLU(geglu)
|
|
2988
|
+
CREATE_GLU(reglu)
|
|
2989
|
+
CREATE_GLU(swiglu)
|
|
2990
|
+
CREATE_GLU(geglu_erf)
|
|
2991
|
+
CREATE_GLU(geglu_quick)
|
|
2992
|
+
#undef CREATE_GLU
|
|
2993
|
+
|
|
2754
2994
|
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2755
2995
|
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2756
2996
|
|
|
@@ -2806,6 +3046,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2806
3046
|
|
|
2807
3047
|
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2808
3048
|
|
|
3049
|
+
// conv2d
|
|
3050
|
+
uint32_t conv2d_WG_SIZE = 256;
|
|
3051
|
+
uint32_t conv2d_BS_K = 128;
|
|
3052
|
+
uint32_t conv2d_BS_CRS = 16;
|
|
3053
|
+
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
|
|
3054
|
+
if (device->subgroup_shuffle &&
|
|
3055
|
+
device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
|
|
3056
|
+
use_collectives = 1;
|
|
3057
|
+
conv2d_BS_CRS = std::min(
|
|
3058
|
+
device->subgroup_size,
|
|
3059
|
+
conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used.
|
|
3060
|
+
}
|
|
3061
|
+
uint32_t conv2d_BS_NPQ = 128;
|
|
3062
|
+
uint32_t conv2d_TS_K = 8;
|
|
3063
|
+
uint32_t conv2d_shmem_req =
|
|
3064
|
+
(conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
|
|
3065
|
+
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
|
|
3066
|
+
conv2d_BS_CRS = 8;
|
|
3067
|
+
if (use_collectives) {
|
|
3068
|
+
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
|
|
3069
|
+
}
|
|
3070
|
+
}
|
|
3071
|
+
|
|
3072
|
+
if (use_collectives) {
|
|
3073
|
+
ggml_vk_create_pipeline(
|
|
3074
|
+
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
|
|
3075
|
+
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
|
3076
|
+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
|
|
3077
|
+
} else {
|
|
3078
|
+
ggml_vk_create_pipeline(
|
|
3079
|
+
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
|
|
3080
|
+
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
|
3081
|
+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
|
|
3082
|
+
false);
|
|
3083
|
+
}
|
|
3084
|
+
|
|
2809
3085
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
2810
3086
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
2811
3087
|
|
|
@@ -3118,6 +3394,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3118
3394
|
|
|
3119
3395
|
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
|
3120
3396
|
|
|
3397
|
+
#if defined(VK_KHR_shader_bfloat16)
|
|
3398
|
+
device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
|
|
3399
|
+
#else
|
|
3400
|
+
device->bf16 = false;
|
|
3401
|
+
#endif
|
|
3402
|
+
|
|
3121
3403
|
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
|
3122
3404
|
|
|
3123
3405
|
if (device->subgroup_size_control) {
|
|
@@ -3431,6 +3713,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3431
3713
|
|
|
3432
3714
|
device->idx = idx;
|
|
3433
3715
|
|
|
3716
|
+
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
|
|
3717
|
+
|
|
3434
3718
|
return device;
|
|
3435
3719
|
}
|
|
3436
3720
|
|
|
@@ -3458,6 +3742,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3458
3742
|
bool coopmat_support = false;
|
|
3459
3743
|
bool coopmat2_support = false;
|
|
3460
3744
|
bool integer_dot_product = false;
|
|
3745
|
+
bool bfloat16_support = false;
|
|
3461
3746
|
|
|
3462
3747
|
for (auto properties : ext_props) {
|
|
3463
3748
|
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
@@ -3478,6 +3763,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3478
3763
|
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
3479
3764
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
3480
3765
|
integer_dot_product = true;
|
|
3766
|
+
#endif
|
|
3767
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
3768
|
+
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
|
3769
|
+
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
3770
|
+
bfloat16_support = true;
|
|
3481
3771
|
#endif
|
|
3482
3772
|
}
|
|
3483
3773
|
}
|
|
@@ -3544,10 +3834,25 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3544
3834
|
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
|
3545
3835
|
}
|
|
3546
3836
|
|
|
3837
|
+
#if defined(VK_KHR_shader_bfloat16)
|
|
3838
|
+
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
|
3839
|
+
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
|
|
3840
|
+
if (bfloat16_support) {
|
|
3841
|
+
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
|
|
3842
|
+
last_struct = (VkBaseOutStructure *)&bfloat16_features;
|
|
3843
|
+
}
|
|
3844
|
+
#endif
|
|
3845
|
+
|
|
3547
3846
|
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
|
3548
3847
|
|
|
3549
3848
|
fp16 = fp16 && vk12_features.shaderFloat16;
|
|
3550
3849
|
|
|
3850
|
+
#if defined(VK_KHR_shader_bfloat16)
|
|
3851
|
+
bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
|
|
3852
|
+
#else
|
|
3853
|
+
bool bf16 = false;
|
|
3854
|
+
#endif
|
|
3855
|
+
|
|
3551
3856
|
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
|
|
3552
3857
|
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
|
3553
3858
|
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
|
@@ -3565,8 +3870,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3565
3870
|
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
|
3566
3871
|
|
|
3567
3872
|
std::string device_name = props2.properties.deviceName.data();
|
|
3568
|
-
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
|
3569
|
-
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
|
|
3873
|
+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
|
3874
|
+
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
|
|
3570
3875
|
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
|
|
3571
3876
|
|
|
3572
3877
|
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
|
@@ -3651,7 +3956,6 @@ static void ggml_vk_instance_init() {
|
|
|
3651
3956
|
|
|
3652
3957
|
}
|
|
3653
3958
|
|
|
3654
|
-
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
|
3655
3959
|
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
|
3656
3960
|
|
|
3657
3961
|
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
|
@@ -4124,6 +4428,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
|
|
4124
4428
|
return nullptr;
|
|
4125
4429
|
}
|
|
4126
4430
|
|
|
4431
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4127
4432
|
device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
|
|
4128
4433
|
|
|
4129
4434
|
return buf->ptr;
|
|
@@ -4134,6 +4439,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4134
4439
|
return;
|
|
4135
4440
|
}
|
|
4136
4441
|
VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
|
|
4442
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4443
|
+
|
|
4137
4444
|
vk_buffer buf;
|
|
4138
4445
|
size_t index;
|
|
4139
4446
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4156,6 +4463,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
|
|
|
4156
4463
|
}
|
|
4157
4464
|
|
|
4158
4465
|
static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
|
|
4466
|
+
std::lock_guard<std::recursive_mutex> guard(device->mutex);
|
|
4159
4467
|
buf = nullptr;
|
|
4160
4468
|
buf_offset = 0;
|
|
4161
4469
|
for (size_t i = 0; i < device->pinned_memory.size(); i++) {
|
|
@@ -4457,7 +4765,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
|
|
|
4457
4765
|
memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
|
|
4458
4766
|
}
|
|
4459
4767
|
} else {
|
|
4460
|
-
std::lock_guard<std::
|
|
4768
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4461
4769
|
|
|
4462
4770
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4463
4771
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
@@ -4548,7 +4856,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
|
|
|
4548
4856
|
|
|
4549
4857
|
memcpy(dst, (uint8_t *) src->ptr + offset, size);
|
|
4550
4858
|
} else {
|
|
4551
|
-
std::lock_guard<std::
|
|
4859
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4552
4860
|
|
|
4553
4861
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
4554
4862
|
ggml_vk_ctx_begin(src->device, subctx);
|
|
@@ -4578,7 +4886,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
|
|
|
4578
4886
|
|
|
4579
4887
|
static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
|
|
4580
4888
|
if (src->device == dst->device) {
|
|
4581
|
-
std::lock_guard<std::
|
|
4889
|
+
std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
|
|
4582
4890
|
VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
|
|
4583
4891
|
// Copy within the device
|
|
4584
4892
|
vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
|
|
@@ -4613,7 +4921,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
|
|
|
4613
4921
|
static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
|
|
4614
4922
|
VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
|
|
4615
4923
|
|
|
4616
|
-
std::lock_guard<std::
|
|
4924
|
+
std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
|
|
4617
4925
|
vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
|
|
4618
4926
|
ggml_vk_ctx_begin(dst->device, subctx);
|
|
4619
4927
|
subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
|
|
@@ -4762,7 +5070,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
|
|
|
4762
5070
|
return
|
|
4763
5071
|
tensor->nb[0] == ggml_type_size(tensor->type) &&
|
|
4764
5072
|
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
|
|
4765
|
-
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
|
|
5073
|
+
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
|
|
4766
5074
|
}
|
|
4767
5075
|
|
|
4768
5076
|
static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
|
|
@@ -4840,9 +5148,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
4840
5148
|
// type size must be exactly 2 or 4.
|
|
4841
5149
|
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
|
|
4842
5150
|
if ((ggml_type_size(src->type) % 4) == 0) {
|
|
4843
|
-
|
|
5151
|
+
if (contig) {
|
|
5152
|
+
return ctx->device->pipeline_contig_cpy_f32_f32;
|
|
5153
|
+
} else {
|
|
5154
|
+
return ctx->device->pipeline_cpy_f32_f32;
|
|
5155
|
+
}
|
|
4844
5156
|
} else {
|
|
4845
|
-
|
|
5157
|
+
if (contig) {
|
|
5158
|
+
return ctx->device->pipeline_contig_cpy_f16_f16;
|
|
5159
|
+
} else {
|
|
5160
|
+
return ctx->device->pipeline_cpy_f16_f16;
|
|
5161
|
+
}
|
|
4846
5162
|
}
|
|
4847
5163
|
}
|
|
4848
5164
|
|
|
@@ -4903,7 +5219,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4903
5219
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
4904
5220
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
4905
5221
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
4906
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5222
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
4907
5223
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
4908
5224
|
|
|
4909
5225
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5131,7 +5447,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
5131
5447
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
5132
5448
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5133
5449
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
|
|
5134
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
5450
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5135
5451
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5136
5452
|
|
|
5137
5453
|
const uint64_t ne00 = src0->ne[0];
|
|
@@ -5732,7 +6048,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
5732
6048
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
|
5733
6049
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
|
5734
6050
|
std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
|
|
5735
|
-
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
|
6051
|
+
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
|
|
5736
6052
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
|
5737
6053
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
|
5738
6054
|
|
|
@@ -5926,14 +6242,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5926
6242
|
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
|
5927
6243
|
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
|
|
5928
6244
|
} else {
|
|
5929
|
-
|
|
6245
|
+
// Split based on number of ids, to fit in shared memory
|
|
6246
|
+
const uint32_t nei0 = (uint32_t)src2->ne[0];
|
|
6247
|
+
const uint32_t nei1 = (uint32_t)src2->ne[1];
|
|
6248
|
+
|
|
6249
|
+
GGML_ASSERT(nei0 <= 4096);
|
|
6250
|
+
const uint32_t split_size = std::min(nei1, 4096u / nei0);
|
|
6251
|
+
|
|
6252
|
+
ggml_tensor src1_copy = *src1;
|
|
6253
|
+
ggml_tensor src2_copy = *src2;
|
|
6254
|
+
ggml_tensor dst_copy = *dst;
|
|
6255
|
+
|
|
6256
|
+
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
|
|
6257
|
+
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
|
|
6258
|
+
|
|
6259
|
+
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
|
|
6260
|
+
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
|
|
6261
|
+
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
|
|
6262
|
+
|
|
6263
|
+
src1_copy.ne[2] = n_tokens;
|
|
6264
|
+
src2_copy.ne[1] = n_tokens;
|
|
6265
|
+
dst_copy.ne[2] = n_tokens;
|
|
6266
|
+
|
|
6267
|
+
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
|
|
6268
|
+
}
|
|
5930
6269
|
}
|
|
5931
6270
|
}
|
|
5932
6271
|
|
|
5933
|
-
static bool
|
|
6272
|
+
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
|
|
6273
|
+
// Needs to be kept up to date on shader changes
|
|
6274
|
+
GGML_UNUSED(hsv);
|
|
6275
|
+
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
6276
|
+
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
|
|
6277
|
+
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
6278
|
+
|
|
6279
|
+
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
6280
|
+
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
|
6281
|
+
|
|
6282
|
+
const uint32_t masksh = Bc * Br * sizeof(float);
|
|
6283
|
+
|
|
6284
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
|
6285
|
+
|
|
6286
|
+
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
|
6287
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
6288
|
+
|
|
6289
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
|
6290
|
+
|
|
6291
|
+
return supported;
|
|
6292
|
+
}
|
|
6293
|
+
|
|
6294
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
|
|
5934
6295
|
// Needs to be kept up to date on shader changes
|
|
6296
|
+
GGML_UNUSED(hsv);
|
|
5935
6297
|
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
5936
|
-
const uint32_t Br =
|
|
6298
|
+
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
|
|
5937
6299
|
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
5938
6300
|
|
|
5939
6301
|
const uint32_t acctype = f32acc ? 4 : 2;
|
|
@@ -5942,12 +6304,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5942
6304
|
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
5943
6305
|
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
5944
6306
|
|
|
5945
|
-
const uint32_t Qf = Br * (
|
|
6307
|
+
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
|
|
5946
6308
|
|
|
5947
|
-
const uint32_t sfshstride = (
|
|
6309
|
+
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
|
5948
6310
|
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
5949
6311
|
|
|
5950
|
-
const uint32_t kshstride =
|
|
6312
|
+
const uint32_t kshstride = hsk / 4 + 2;
|
|
5951
6313
|
const uint32_t ksh = Bc * kshstride * f16vec4;
|
|
5952
6314
|
|
|
5953
6315
|
const uint32_t slope = Br * sizeof(float);
|
|
@@ -5955,7 +6317,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|
|
5955
6317
|
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
|
5956
6318
|
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
5957
6319
|
|
|
5958
|
-
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(
|
|
6320
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
5959
6321
|
|
|
5960
6322
|
return supported;
|
|
5961
6323
|
}
|
|
@@ -5977,13 +6339,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5977
6339
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
5978
6340
|
|
|
5979
6341
|
const uint32_t nem1 = mask ? mask->ne[1] : 0;
|
|
5980
|
-
const uint32_t
|
|
6342
|
+
const uint32_t nem2 = mask ? mask->ne[2] : 0;
|
|
6343
|
+
const uint32_t nem3 = mask ? mask->ne[3] : 0;
|
|
5981
6344
|
|
|
5982
|
-
const uint32_t
|
|
6345
|
+
const uint32_t HSK = nek0;
|
|
6346
|
+
const uint32_t HSV = nev0;
|
|
5983
6347
|
uint32_t N = neq1;
|
|
5984
6348
|
const uint32_t KV = nek1;
|
|
5985
6349
|
|
|
5986
|
-
GGML_ASSERT(ne0 ==
|
|
6350
|
+
GGML_ASSERT(ne0 == HSV);
|
|
5987
6351
|
GGML_ASSERT(ne2 == N);
|
|
5988
6352
|
|
|
5989
6353
|
// input tensor rows must be contiguous
|
|
@@ -5991,12 +6355,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5991
6355
|
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
|
5992
6356
|
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
|
5993
6357
|
|
|
5994
|
-
GGML_ASSERT(neq0 ==
|
|
5995
|
-
GGML_ASSERT(nek0 == D);
|
|
5996
|
-
GGML_ASSERT(nev0 == D);
|
|
6358
|
+
GGML_ASSERT(neq0 == HSK);
|
|
5997
6359
|
|
|
5998
6360
|
GGML_ASSERT(neq1 == N);
|
|
5999
|
-
GGML_ASSERT(nev0 == D);
|
|
6000
6361
|
|
|
6001
6362
|
GGML_ASSERT(nev1 == nek1);
|
|
6002
6363
|
|
|
@@ -6017,7 +6378,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6017
6378
|
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
6018
6379
|
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
6019
6380
|
|
|
6020
|
-
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device,
|
|
6381
|
+
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
|
|
6021
6382
|
|
|
6022
6383
|
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
6023
6384
|
path = FA_SCALAR;
|
|
@@ -6037,7 +6398,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6037
6398
|
case FA_SCALAR:
|
|
6038
6399
|
case FA_COOPMAT1:
|
|
6039
6400
|
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
6040
|
-
max_gqa =
|
|
6401
|
+
max_gqa = get_fa_scalar_num_large_rows(HSV);
|
|
6041
6402
|
break;
|
|
6042
6403
|
case FA_COOPMAT2:
|
|
6043
6404
|
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
@@ -6047,7 +6408,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6047
6408
|
}
|
|
6048
6409
|
|
|
6049
6410
|
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
6050
|
-
qk_ratio * nek2 == neq2 && nek2 == nev2 &&
|
|
6411
|
+
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
|
6051
6412
|
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
6052
6413
|
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
6053
6414
|
// and change addressing calculations to index Q's dimension 2.
|
|
@@ -6070,47 +6431,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6070
6431
|
path = FA_SCALAR;
|
|
6071
6432
|
}
|
|
6072
6433
|
|
|
6434
|
+
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
|
6435
|
+
if (path == FA_SCALAR &&
|
|
6436
|
+
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
|
|
6437
|
+
small_rows = true;
|
|
6438
|
+
}
|
|
6439
|
+
|
|
6073
6440
|
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
|
6074
6441
|
|
|
6442
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
|
|
6443
|
+
|
|
6075
6444
|
switch (path) {
|
|
6076
6445
|
case FA_SCALAR:
|
|
6077
|
-
|
|
6078
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
6079
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
6080
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
6081
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
6082
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
6083
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
6084
|
-
default:
|
|
6085
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6086
|
-
return;
|
|
6087
|
-
}
|
|
6446
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
|
|
6088
6447
|
break;
|
|
6089
6448
|
case FA_COOPMAT1:
|
|
6090
|
-
|
|
6091
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6092
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6093
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6094
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6095
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6096
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
|
6097
|
-
default:
|
|
6098
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6099
|
-
return;
|
|
6100
|
-
}
|
|
6449
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
|
|
6101
6450
|
break;
|
|
6102
6451
|
case FA_COOPMAT2:
|
|
6103
|
-
|
|
6104
|
-
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6105
|
-
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6106
|
-
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6107
|
-
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6108
|
-
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6109
|
-
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
|
6110
|
-
default:
|
|
6111
|
-
GGML_ASSERT(!"unsupported D value");
|
|
6112
|
-
return;
|
|
6113
|
-
}
|
|
6452
|
+
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
|
|
6114
6453
|
break;
|
|
6115
6454
|
default:
|
|
6116
6455
|
GGML_ASSERT(0);
|
|
@@ -6138,21 +6477,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6138
6477
|
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
6139
6478
|
|
|
6140
6479
|
// Try to use split_k when KV is large enough to be worth the overhead
|
|
6141
|
-
if (workgroups_x == 1 && shader_core_count > 0
|
|
6480
|
+
if (workgroups_x == 1 && shader_core_count > 0) {
|
|
6142
6481
|
// Try to run two workgroups per SM.
|
|
6143
|
-
split_k =
|
|
6482
|
+
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
|
6144
6483
|
if (split_k > 1) {
|
|
6145
6484
|
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
|
6146
6485
|
// of "align", so recompute split_k based on that.
|
|
6147
|
-
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
|
6486
|
+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
|
6148
6487
|
split_k = CEIL_DIV(KV, split_kv);
|
|
6149
6488
|
workgroups_x = split_k;
|
|
6150
6489
|
}
|
|
6151
6490
|
}
|
|
6152
6491
|
|
|
6153
|
-
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
|
|
6154
|
-
// and the per-row m and L values (ne1 rows).
|
|
6155
|
-
const uint64_t split_k_size = split_k > 1 ? (
|
|
6492
|
+
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
|
6493
|
+
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
|
6494
|
+
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
|
6156
6495
|
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
|
6157
6496
|
GGML_ABORT("Requested preallocation size is too large");
|
|
6158
6497
|
}
|
|
@@ -6239,18 +6578,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6239
6578
|
}
|
|
6240
6579
|
}
|
|
6241
6580
|
|
|
6581
|
+
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
|
|
6582
|
+
|
|
6242
6583
|
const vk_flash_attn_push_constants pc = { N, KV,
|
|
6243
6584
|
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
|
|
6244
6585
|
(uint32_t)neq2, (uint32_t)neq3,
|
|
6245
6586
|
(uint32_t)nek2, (uint32_t)nek3,
|
|
6246
6587
|
(uint32_t)nev2, (uint32_t)nev3,
|
|
6247
|
-
nem1,
|
|
6588
|
+
nem1, nem2, nem3,
|
|
6248
6589
|
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
|
|
6249
6590
|
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
|
|
6250
6591
|
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
|
6251
|
-
nbm1,
|
|
6252
6592
|
scale, max_bias, logit_softcap,
|
|
6253
|
-
|
|
6593
|
+
mask_n_head_log2, m0, m1,
|
|
6254
6594
|
gqa_ratio, split_kv, split_k };
|
|
6255
6595
|
|
|
6256
6596
|
ggml_vk_sync_buffers(subctx);
|
|
@@ -6271,13 +6611,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6271
6611
|
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
|
6272
6612
|
|
|
6273
6613
|
ggml_vk_sync_buffers(subctx);
|
|
6274
|
-
const std::array<uint32_t,
|
|
6614
|
+
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
|
|
6275
6615
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
|
6276
6616
|
{
|
|
6277
6617
|
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
|
6278
6618
|
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
|
6279
6619
|
},
|
|
6280
|
-
pc2, { (uint32_t)ne1,
|
|
6620
|
+
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
|
6281
6621
|
} else {
|
|
6282
6622
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
6283
6623
|
{
|
|
@@ -6353,8 +6693,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6353
6693
|
}
|
|
6354
6694
|
return nullptr;
|
|
6355
6695
|
case GGML_OP_UPSCALE:
|
|
6356
|
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
|
6357
|
-
|
|
6696
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6697
|
+
int mode = ggml_get_op_params_i32(dst, 0);
|
|
6698
|
+
switch (mode) {
|
|
6699
|
+
case GGML_SCALE_MODE_NEAREST:
|
|
6700
|
+
return ctx->device->pipeline_upscale_nearest_f32;
|
|
6701
|
+
case GGML_SCALE_MODE_BILINEAR:
|
|
6702
|
+
return ctx->device->pipeline_upscale_bilinear_f32;
|
|
6703
|
+
case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
|
|
6704
|
+
return ctx->device->pipeline_upscale_bilinear_ac_f32;
|
|
6705
|
+
}
|
|
6358
6706
|
}
|
|
6359
6707
|
return nullptr;
|
|
6360
6708
|
case GGML_OP_SCALE:
|
|
@@ -6387,6 +6735,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6387
6735
|
return ctx->device->pipeline_pad_f32;
|
|
6388
6736
|
}
|
|
6389
6737
|
return nullptr;
|
|
6738
|
+
case GGML_OP_ROLL:
|
|
6739
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6740
|
+
return ctx->device->pipeline_roll_f32;
|
|
6741
|
+
}
|
|
6742
|
+
return nullptr;
|
|
6390
6743
|
case GGML_OP_REPEAT:
|
|
6391
6744
|
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
|
|
6392
6745
|
return ctx->device->pipeline_repeat_f32;
|
|
@@ -6401,6 +6754,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6401
6754
|
case GGML_OP_CONT:
|
|
6402
6755
|
case GGML_OP_DUP:
|
|
6403
6756
|
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
|
|
6757
|
+
case GGML_OP_SET_ROWS:
|
|
6758
|
+
return ctx->device->pipeline_set_rows[dst->type];
|
|
6404
6759
|
case GGML_OP_SILU_BACK:
|
|
6405
6760
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6406
6761
|
return ctx->device->pipeline_silu_back_f32;
|
|
@@ -6418,7 +6773,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6418
6773
|
return nullptr;
|
|
6419
6774
|
case GGML_OP_RMS_NORM:
|
|
6420
6775
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6421
|
-
return ctx->device->pipeline_rms_norm_f32;
|
|
6776
|
+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
|
|
6422
6777
|
}
|
|
6423
6778
|
return nullptr;
|
|
6424
6779
|
case GGML_OP_RMS_NORM_BACK:
|
|
@@ -6443,6 +6798,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6443
6798
|
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
6444
6799
|
case GGML_UNARY_OP_GELU:
|
|
6445
6800
|
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
|
|
6801
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
6802
|
+
return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
|
|
6446
6803
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
6447
6804
|
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
|
|
6448
6805
|
case GGML_UNARY_OP_RELU:
|
|
@@ -6455,6 +6812,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6455
6812
|
break;
|
|
6456
6813
|
}
|
|
6457
6814
|
return nullptr;
|
|
6815
|
+
case GGML_OP_GLU:
|
|
6816
|
+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
|
6817
|
+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
|
|
6818
|
+
(src0->type != dst->type)) {
|
|
6819
|
+
return nullptr;
|
|
6820
|
+
}
|
|
6821
|
+
|
|
6822
|
+
switch (ggml_get_glu_op(dst)) {
|
|
6823
|
+
case GGML_GLU_OP_GEGLU:
|
|
6824
|
+
return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
|
|
6825
|
+
case GGML_GLU_OP_REGLU:
|
|
6826
|
+
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
|
6827
|
+
case GGML_GLU_OP_SWIGLU:
|
|
6828
|
+
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
|
6829
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
6830
|
+
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
|
|
6831
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
6832
|
+
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
|
|
6833
|
+
default:
|
|
6834
|
+
break;
|
|
6835
|
+
}
|
|
6836
|
+
return nullptr;
|
|
6458
6837
|
case GGML_OP_DIAG_MASK_INF:
|
|
6459
6838
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6460
6839
|
return ctx->device->pipeline_diag_mask_inf_f32;
|
|
@@ -6578,6 +6957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
6578
6957
|
return ctx->device->pipeline_leaky_relu_f32;
|
|
6579
6958
|
}
|
|
6580
6959
|
return nullptr;
|
|
6960
|
+
case GGML_OP_CONV_2D:
|
|
6961
|
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
|
6962
|
+
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
|
6963
|
+
return ctx->device->pipeline_conv2d_f32;
|
|
6964
|
+
}
|
|
6965
|
+
return nullptr;
|
|
6581
6966
|
case GGML_OP_CONV_2D_DW:
|
|
6582
6967
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6583
6968
|
if (ggml_is_contiguous(src1)) {
|
|
@@ -6615,6 +7000,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
6615
7000
|
case GGML_OP_RMS_NORM:
|
|
6616
7001
|
case GGML_OP_CONV_2D_DW:
|
|
6617
7002
|
case GGML_OP_IM2COL:
|
|
7003
|
+
case GGML_OP_SET_ROWS:
|
|
6618
7004
|
return true;
|
|
6619
7005
|
default:
|
|
6620
7006
|
return false;
|
|
@@ -6899,6 +7285,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6899
7285
|
const uint32_t OW = dst->ne[0];
|
|
6900
7286
|
elements = { N * OC * OH * OW, 1, 1};
|
|
6901
7287
|
} break;
|
|
7288
|
+
case GGML_OP_CONV_2D:
|
|
7289
|
+
{
|
|
7290
|
+
// src0 - kernel: [KW, KH, Cin, Cout]
|
|
7291
|
+
// src1 - input: [W, H, Cin, N]
|
|
7292
|
+
// dst - result: [OW, OH, Cout, N]
|
|
7293
|
+
|
|
7294
|
+
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
|
|
7295
|
+
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
|
|
7296
|
+
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
|
|
7297
|
+
};
|
|
7298
|
+
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
|
|
7299
|
+
int64_t W = src1->ne[0];
|
|
7300
|
+
int64_t H = src1->ne[1];
|
|
7301
|
+
int64_t KW = src0->ne[0];
|
|
7302
|
+
int64_t KH = src0->ne[1];
|
|
7303
|
+
int64_t Cout = src0->ne[3];
|
|
7304
|
+
int64_t N = src1->ne[3];
|
|
7305
|
+
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
|
|
7306
|
+
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
|
|
7307
|
+
int64_t NPQ = N * OW * OH;
|
|
7308
|
+
|
|
7309
|
+
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
|
|
7310
|
+
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
|
|
7311
|
+
}
|
|
7312
|
+
break;
|
|
6902
7313
|
case GGML_OP_ADD:
|
|
6903
7314
|
case GGML_OP_SUB:
|
|
6904
7315
|
case GGML_OP_DIV:
|
|
@@ -6909,12 +7320,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6909
7320
|
case GGML_OP_COS:
|
|
6910
7321
|
case GGML_OP_CLAMP:
|
|
6911
7322
|
case GGML_OP_PAD:
|
|
7323
|
+
case GGML_OP_ROLL:
|
|
6912
7324
|
case GGML_OP_REPEAT:
|
|
6913
7325
|
case GGML_OP_REPEAT_BACK:
|
|
6914
7326
|
case GGML_OP_CPY:
|
|
6915
7327
|
case GGML_OP_CONCAT:
|
|
6916
7328
|
case GGML_OP_UPSCALE:
|
|
6917
7329
|
case GGML_OP_UNARY:
|
|
7330
|
+
case GGML_OP_GLU:
|
|
6918
7331
|
case GGML_OP_CONV_2D_DW:
|
|
6919
7332
|
{
|
|
6920
7333
|
uint32_t ne = ggml_nelements(dst);
|
|
@@ -6927,6 +7340,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6927
7340
|
ne *= ggml_type_size(src0->type) / 2;
|
|
6928
7341
|
}
|
|
6929
7342
|
}
|
|
7343
|
+
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
|
|
7344
|
+
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
|
|
7345
|
+
// So divide by block size here before splitting into 512x512 groups.
|
|
7346
|
+
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7347
|
+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
|
|
7348
|
+
}
|
|
6930
7349
|
if (ne > 262144) {
|
|
6931
7350
|
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
6932
7351
|
} else if (ne > 512) {
|
|
@@ -6935,6 +7354,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6935
7354
|
elements = { ne, 1, 1 };
|
|
6936
7355
|
}
|
|
6937
7356
|
} break;
|
|
7357
|
+
case GGML_OP_SET_ROWS:
|
|
7358
|
+
{
|
|
7359
|
+
uint32_t ne = ggml_nelements(src0);
|
|
7360
|
+
if (ggml_is_quantized(dst->type)) {
|
|
7361
|
+
// quants run 32 threads each doing QUANT_K elements
|
|
7362
|
+
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
|
|
7363
|
+
} else {
|
|
7364
|
+
// scalar types do one element per thread, running 512 threads
|
|
7365
|
+
ne = CEIL_DIV(ne, 512);
|
|
7366
|
+
}
|
|
7367
|
+
if (ne > 262144) {
|
|
7368
|
+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
7369
|
+
} else if (ne > 512) {
|
|
7370
|
+
elements = { 512, CEIL_DIV(ne, 512), 1 };
|
|
7371
|
+
} else {
|
|
7372
|
+
elements = { ne, 1, 1 };
|
|
7373
|
+
}
|
|
7374
|
+
}
|
|
7375
|
+
break;
|
|
6938
7376
|
default:
|
|
6939
7377
|
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
|
|
6940
7378
|
break;
|
|
@@ -6955,7 +7393,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6955
7393
|
}
|
|
6956
7394
|
}
|
|
6957
7395
|
|
|
6958
|
-
if (op == GGML_OP_SOFT_MAX) {
|
|
7396
|
+
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
|
|
6959
7397
|
// Empty src1 is possible in soft_max, but the shader needs a buffer
|
|
6960
7398
|
vk_subbuffer subbuf_y;
|
|
6961
7399
|
if (use_src1) {
|
|
@@ -7344,14 +7782,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
7344
7782
|
|
|
7345
7783
|
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7346
7784
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7785
|
+
const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
|
7786
|
+
|
|
7787
|
+
float sf0 = (float)dst->ne[0] / src0->ne[0];
|
|
7788
|
+
float sf1 = (float)dst->ne[1] / src0->ne[1];
|
|
7789
|
+
float sf2 = (float)dst->ne[2] / src0->ne[2];
|
|
7790
|
+
float sf3 = (float)dst->ne[3] / src0->ne[3];
|
|
7347
7791
|
|
|
7348
|
-
|
|
7349
|
-
|
|
7350
|
-
|
|
7351
|
-
|
|
7792
|
+
if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
|
7793
|
+
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
|
7794
|
+
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
|
7795
|
+
}
|
|
7352
7796
|
|
|
7353
7797
|
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
|
7354
7798
|
(uint32_t)ggml_nelements(dst), 0, 0,
|
|
7799
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
|
|
7355
7800
|
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7356
7801
|
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
|
7357
7802
|
sf0, sf1, sf2, sf3,
|
|
@@ -7359,123 +7804,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
7359
7804
|
}
|
|
7360
7805
|
|
|
7361
7806
|
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7362
|
-
|
|
7363
|
-
|
|
7364
|
-
|
|
7807
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7808
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7809
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7365
7810
|
|
|
7366
|
-
ggml_vk_op_f32
|
|
7367
|
-
(uint32_t)ggml_nelements(src0),
|
|
7368
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7369
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7370
|
-
0,
|
|
7371
|
-
op_params[0], 0.0f,
|
|
7372
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7373
|
-
}, dryrun);
|
|
7811
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
|
|
7374
7812
|
}
|
|
7375
7813
|
|
|
7376
7814
|
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7377
|
-
|
|
7378
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7379
|
-
|
|
7380
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
|
|
7381
|
-
(uint32_t)ggml_nelements(src0),
|
|
7382
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7383
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7384
|
-
0,
|
|
7385
|
-
0.0f, 0.0f,
|
|
7386
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7387
|
-
}, dryrun);
|
|
7815
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7388
7816
|
}
|
|
7389
7817
|
|
|
7390
7818
|
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7391
|
-
|
|
7392
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7393
|
-
|
|
7394
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
|
7395
|
-
(uint32_t)ggml_nelements(src0),
|
|
7396
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7397
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7398
|
-
0,
|
|
7399
|
-
0.0f, 0.0f,
|
|
7400
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7401
|
-
}, dryrun);
|
|
7819
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7402
7820
|
}
|
|
7403
7821
|
|
|
7404
7822
|
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7405
|
-
|
|
7406
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7407
|
-
|
|
7408
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
|
7409
|
-
(uint32_t)ggml_nelements(src0),
|
|
7410
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7411
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7412
|
-
0,
|
|
7413
|
-
0.0f, 0.0f,
|
|
7414
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7415
|
-
}, dryrun);
|
|
7823
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
|
|
7416
7824
|
}
|
|
7417
7825
|
|
|
7418
7826
|
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7419
|
-
|
|
7420
|
-
|
|
7421
|
-
|
|
7827
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7828
|
+
p.param1 = ggml_get_op_params_f32(dst, 0);
|
|
7829
|
+
p.param2 = ggml_get_op_params_f32(dst, 1);
|
|
7422
7830
|
|
|
7423
|
-
ggml_vk_op_f32
|
|
7424
|
-
(uint32_t)ggml_nelements(src0),
|
|
7425
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7426
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7427
|
-
0,
|
|
7428
|
-
op_params[0], op_params[1],
|
|
7429
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7430
|
-
}, dryrun);
|
|
7831
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
|
|
7431
7832
|
}
|
|
7432
7833
|
|
|
7433
7834
|
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7434
|
-
|
|
7435
|
-
|
|
7835
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7836
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
|
|
7837
|
+
}
|
|
7436
7838
|
|
|
7437
|
-
|
|
7438
|
-
|
|
7439
|
-
|
|
7440
|
-
|
|
7441
|
-
|
|
7442
|
-
|
|
7443
|
-
|
|
7444
|
-
|
|
7839
|
+
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7840
|
+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
|
|
7841
|
+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
|
|
7842
|
+
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
|
|
7843
|
+
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
|
|
7844
|
+
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
|
|
7845
|
+
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
|
|
7846
|
+
|
|
7847
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
|
7848
|
+
memcpy(&p.param1, &s01_packed, sizeof(float));
|
|
7849
|
+
memcpy(&p.param2, &s23_packed, sizeof(float));
|
|
7850
|
+
|
|
7851
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
|
|
7445
7852
|
}
|
|
7446
7853
|
|
|
7447
7854
|
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7448
|
-
|
|
7449
|
-
|
|
7450
|
-
|
|
7451
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
|
|
7452
|
-
(uint32_t)ggml_nelements(dst),
|
|
7453
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7454
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7455
|
-
0,
|
|
7456
|
-
0.0f, 0.0f,
|
|
7457
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7458
|
-
}, dryrun);
|
|
7855
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7856
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
|
|
7459
7857
|
}
|
|
7460
7858
|
|
|
7461
7859
|
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7462
|
-
|
|
7463
|
-
|
|
7464
|
-
|
|
7465
|
-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
|
|
7466
|
-
(uint32_t)ggml_nelements(dst),
|
|
7467
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7468
|
-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7469
|
-
0,
|
|
7470
|
-
0.0f, 0.0f,
|
|
7471
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7472
|
-
}, dryrun);
|
|
7860
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
|
7861
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
|
|
7473
7862
|
}
|
|
7474
7863
|
|
|
7475
7864
|
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7476
|
-
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7477
|
-
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7478
|
-
|
|
7479
7865
|
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
|
7480
7866
|
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7481
7867
|
// Convert from number of logical elements to 2- or 4-byte units.
|
|
@@ -7487,13 +7873,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7487
7873
|
}
|
|
7488
7874
|
}
|
|
7489
7875
|
|
|
7490
|
-
|
|
7491
|
-
|
|
7492
|
-
|
|
7493
|
-
|
|
7876
|
+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
|
|
7877
|
+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
|
|
7878
|
+
}
|
|
7879
|
+
|
|
7880
|
+
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7881
|
+
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7882
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7883
|
+
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7884
|
+
|
|
7885
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
|
|
7886
|
+
(uint32_t)ggml_nelements(src0),
|
|
7887
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7888
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7889
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7494
7890
|
0,
|
|
7495
|
-
0.0f, 0.0f,
|
|
7496
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7891
|
+
0.0f, 0.0f, 0,
|
|
7497
7892
|
}, dryrun);
|
|
7498
7893
|
}
|
|
7499
7894
|
|
|
@@ -7518,18 +7913,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
7518
7913
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
|
|
7519
7914
|
}
|
|
7520
7915
|
|
|
7521
|
-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7522
|
-
float * op_params = (float *)dst->op_params;
|
|
7916
|
+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
|
|
7523
7917
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7918
|
+
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
|
7524
7919
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7525
7920
|
|
|
7526
|
-
ggml_vk_op_f32<
|
|
7921
|
+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
|
|
7527
7922
|
(uint32_t)ggml_nelements(src0),
|
|
7528
|
-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7529
|
-
(uint32_t)
|
|
7923
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7924
|
+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
|
7925
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7530
7926
|
0,
|
|
7531
|
-
op_params[0], 0.0f,
|
|
7532
|
-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7927
|
+
op_params[0], 0.0f, 0,
|
|
7533
7928
|
}, dryrun);
|
|
7534
7929
|
}
|
|
7535
7930
|
|
|
@@ -7547,6 +7942,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
7547
7942
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
|
7548
7943
|
}
|
|
7549
7944
|
|
|
7945
|
+
// Enqueue a GGML_OP_GLU node through ggml_vk_op_f32.
//
//   src0   - first input; must be contiguous (asserted).
//   src1   - optional second input. When non-null (the "split" form) it must
//            have the same type and row length as src0, and dst must have that
//            row length too. When null, dst's row length must equal half of
//            src0's row length.
//   dst    - output tensor; dst->op_params[1] is read as the "swapped" flag.
//   dryrun - forwarded unchanged to ggml_vk_op_f32.
//
// The push constants passed are, in order: element count of dst, src0 row
// length, dst row length, and a mode selector:
//   2 = split form (src1 present),
//   1 = single input with the swapped flag set,
//   0 = single input with the swapped flag clear.
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    GGML_ASSERT(ggml_is_contiguous(src0));

    uint32_t mode;
    if (src1 != nullptr) {
        // Split form: gate and value arrive as two separate, equally shaped tensors.
        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
        GGML_ASSERT(src0->type == src1->type);
        mode = 2;
    } else {
        // Single-tensor form: each src0 row is twice as long as a dst row.
        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
        mode = (dst->op_params[1] != 0) ? 1 : 0;
    }

    const uint32_t n_out    = (uint32_t)ggml_nelements(dst);
    const uint32_t src_cols = (uint32_t)src0->ne[0];
    const uint32_t dst_cols = (uint32_t)dst->ne[0];

    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { n_out, src_cols, dst_cols, mode }, dryrun);
}
|
|
7963
|
+
|
|
7550
7964
|
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
7551
7965
|
int32_t * op_params = (int32_t *)dst->op_params;
|
|
7552
7966
|
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
|
@@ -7562,7 +7976,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7562
7976
|
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
|
|
7563
7977
|
const uint32_t nrows_y = (uint32_t)src0->ne[1];
|
|
7564
7978
|
|
|
7565
|
-
const uint32_t
|
|
7979
|
+
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
|
|
7980
|
+
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
|
|
7981
|
+
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
|
|
7982
|
+
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
|
|
7983
|
+
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
|
|
7984
|
+
|
|
7985
|
+
const uint32_t n_head_kv = src0->ne[2];
|
|
7566
7986
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
7567
7987
|
|
|
7568
7988
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -7571,6 +7991,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
7571
7991
|
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
|
|
7572
7992
|
ncols,
|
|
7573
7993
|
src1 != nullptr ? nrows_y : (uint32_t)0,
|
|
7994
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
|
|
7995
|
+
ne12, ne13,
|
|
7996
|
+
nb11, nb12, nb13,
|
|
7574
7997
|
scale, max_bias,
|
|
7575
7998
|
m0, m1,
|
|
7576
7999
|
n_head_log2,
|
|
@@ -7753,6 +8176,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
7753
8176
|
}, dryrun);
|
|
7754
8177
|
}
|
|
7755
8178
|
|
|
8179
|
+
// Enqueue a GGML_OP_CONV_2D node through ggml_vk_op_f32.
//
// Requirements (all asserted):
//   - src0, src1 and dst are GGML_TYPE_F32 with an innermost stride of
//     sizeof(float);
//   - src0's dim 3 equals dst's dim 2, and src0's dim 2 equals src1's dim 2.
//
// The ne*/nb* names used below are the locals introduced by
// GGML_TENSOR_BINARY_OP_LOCALS (presumably ne0x/nb0x for src0, ne1x/nb1x for
// src1 and ne/nb for dst, per the usual ggml convention - confirm in ggml.h).
//
// dst->op_params[0..5] are copied, in order, into s0, s1, p0, p1, d0, d1.
// dryrun is forwarded unchanged to ggml_vk_op_f32.
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
                            const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    // Only the all-F32 combination is handled by this path.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_TENSOR_BINARY_OP_LOCALS

    // Innermost dimension must be densely packed floats so that the byte
    // strides below divide cleanly into element strides.
    GGML_ASSERT(nb00 == sizeof(float));
    GGML_ASSERT(nb10 == sizeof(float));
    GGML_ASSERT(nb0 == sizeof(float));

    const auto * conv_params = dst->op_params;

    vk_op_conv2d_push_constants pc{};

    // Channel / batch counts.
    pc.Cout = static_cast<uint32_t>(ne03);
    pc.Cin  = static_cast<uint32_t>(ne02);
    pc.N    = static_cast<uint32_t>(ne13);

    // Kernel, input and output spatial extents.
    pc.KW = static_cast<uint32_t>(ne00);
    pc.KH = static_cast<uint32_t>(ne01);
    pc.W  = static_cast<uint32_t>(ne10);
    pc.H  = static_cast<uint32_t>(ne11);
    pc.OW = static_cast<uint32_t>(ne0);
    pc.OH = static_cast<uint32_t>(ne1);

    // Convolution parameters, taken verbatim from op_params[0..5].
    pc.s0 = static_cast<uint32_t>(conv_params[0]);
    pc.s1 = static_cast<uint32_t>(conv_params[1]);
    pc.p0 = static_cast<uint32_t>(conv_params[2]);
    pc.p1 = static_cast<uint32_t>(conv_params[3]);
    pc.d0 = static_cast<uint32_t>(conv_params[4]);
    pc.d1 = static_cast<uint32_t>(conv_params[5]);

    // Strides are handed over in units of elements, not bytes.
    pc.nb01 = static_cast<uint32_t>(nb01 / nb00);
    pc.nb02 = static_cast<uint32_t>(nb02 / nb00);
    pc.nb03 = static_cast<uint32_t>(nb03 / nb00);

    pc.nb11 = static_cast<uint32_t>(nb11 / nb10);
    pc.nb12 = static_cast<uint32_t>(nb12 / nb10);
    pc.nb13 = static_cast<uint32_t>(nb13 / nb10);

    pc.nb1 = static_cast<uint32_t>(nb1 / nb0);
    pc.nb2 = static_cast<uint32_t>(nb2 / nb0);
    pc.nb3 = static_cast<uint32_t>(nb3 / nb0);

    // Shape consistency between kernel, input and output.
    GGML_ASSERT(ne03 == ne2);
    GGML_ASSERT(ne02 == ne12);

    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(pc), dryrun);
}
|
|
8227
|
+
|
|
7756
8228
|
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7757
8229
|
vk_op_conv2d_dw_push_constants p{};
|
|
7758
8230
|
p.ne = ggml_nelements(dst);
|
|
@@ -8720,11 +9192,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
8720
9192
|
}
|
|
8721
9193
|
}
|
|
8722
9194
|
|
|
8723
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
9195
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
8724
9196
|
|
|
8725
9197
|
// Returns true if node has enqueued work into the queue, false otherwise
|
|
8726
9198
|
// If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
|
8727
|
-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx,
|
|
9199
|
+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
|
|
9200
|
+
ggml_tensor * node = cgraph->nodes[node_idx];
|
|
8728
9201
|
if (ggml_is_empty(node) || !node->buffer) {
|
|
8729
9202
|
return false;
|
|
8730
9203
|
}
|
|
@@ -8749,6 +9222,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8749
9222
|
switch (ggml_get_unary_op(node)) {
|
|
8750
9223
|
case GGML_UNARY_OP_SILU:
|
|
8751
9224
|
case GGML_UNARY_OP_GELU:
|
|
9225
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8752
9226
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8753
9227
|
case GGML_UNARY_OP_RELU:
|
|
8754
9228
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8758,6 +9232,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8758
9232
|
return false;
|
|
8759
9233
|
}
|
|
8760
9234
|
break;
|
|
9235
|
+
case GGML_OP_GLU:
|
|
9236
|
+
switch (ggml_get_glu_op(node)) {
|
|
9237
|
+
case GGML_GLU_OP_GEGLU:
|
|
9238
|
+
case GGML_GLU_OP_REGLU:
|
|
9239
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9240
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9241
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9242
|
+
break;
|
|
9243
|
+
default:
|
|
9244
|
+
return false;
|
|
9245
|
+
}
|
|
9246
|
+
break;
|
|
8761
9247
|
case GGML_OP_REPEAT:
|
|
8762
9248
|
case GGML_OP_REPEAT_BACK:
|
|
8763
9249
|
case GGML_OP_GET_ROWS:
|
|
@@ -8774,7 +9260,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8774
9260
|
case GGML_OP_COS:
|
|
8775
9261
|
case GGML_OP_CLAMP:
|
|
8776
9262
|
case GGML_OP_PAD:
|
|
9263
|
+
case GGML_OP_ROLL:
|
|
8777
9264
|
case GGML_OP_CPY:
|
|
9265
|
+
case GGML_OP_SET_ROWS:
|
|
8778
9266
|
case GGML_OP_CONT:
|
|
8779
9267
|
case GGML_OP_DUP:
|
|
8780
9268
|
case GGML_OP_SILU_BACK:
|
|
@@ -8799,6 +9287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8799
9287
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
8800
9288
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
8801
9289
|
case GGML_OP_POOL_2D:
|
|
9290
|
+
case GGML_OP_CONV_2D:
|
|
8802
9291
|
case GGML_OP_CONV_2D_DW:
|
|
8803
9292
|
case GGML_OP_RWKV_WKV6:
|
|
8804
9293
|
case GGML_OP_RWKV_WKV7:
|
|
@@ -8841,6 +9330,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8841
9330
|
case GGML_OP_CLAMP:
|
|
8842
9331
|
case GGML_OP_PAD:
|
|
8843
9332
|
case GGML_OP_CPY:
|
|
9333
|
+
case GGML_OP_SET_ROWS:
|
|
8844
9334
|
case GGML_OP_CONT:
|
|
8845
9335
|
case GGML_OP_DUP:
|
|
8846
9336
|
case GGML_OP_SILU_BACK:
|
|
@@ -8850,6 +9340,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8850
9340
|
case GGML_OP_RMS_NORM_BACK:
|
|
8851
9341
|
case GGML_OP_L2_NORM:
|
|
8852
9342
|
case GGML_OP_UNARY:
|
|
9343
|
+
case GGML_OP_GLU:
|
|
8853
9344
|
case GGML_OP_DIAG_MASK_INF:
|
|
8854
9345
|
case GGML_OP_SOFT_MAX:
|
|
8855
9346
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -8864,6 +9355,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8864
9355
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
8865
9356
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
8866
9357
|
case GGML_OP_POOL_2D:
|
|
9358
|
+
case GGML_OP_CONV_2D:
|
|
8867
9359
|
case GGML_OP_CONV_2D_DW:
|
|
8868
9360
|
case GGML_OP_LEAKY_RELU:
|
|
8869
9361
|
{
|
|
@@ -8942,12 +9434,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8942
9434
|
case GGML_OP_PAD:
|
|
8943
9435
|
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
|
|
8944
9436
|
|
|
9437
|
+
break;
|
|
9438
|
+
case GGML_OP_ROLL:
|
|
9439
|
+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
|
|
9440
|
+
|
|
8945
9441
|
break;
|
|
8946
9442
|
case GGML_OP_CPY:
|
|
8947
9443
|
case GGML_OP_CONT:
|
|
8948
9444
|
case GGML_OP_DUP:
|
|
8949
9445
|
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
|
|
8950
9446
|
|
|
9447
|
+
break;
|
|
9448
|
+
case GGML_OP_SET_ROWS:
|
|
9449
|
+
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9450
|
+
|
|
8951
9451
|
break;
|
|
8952
9452
|
case GGML_OP_SILU_BACK:
|
|
8953
9453
|
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8962,8 +9462,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8962
9462
|
|
|
8963
9463
|
break;
|
|
8964
9464
|
case GGML_OP_RMS_NORM:
|
|
8965
|
-
|
|
8966
|
-
|
|
9465
|
+
if (ctx->num_additional_fused_ops > 0) {
|
|
9466
|
+
// fused rms_norm + mul
|
|
9467
|
+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
9468
|
+
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
|
|
9469
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
|
|
9470
|
+
} else {
|
|
9471
|
+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
|
|
9472
|
+
}
|
|
8967
9473
|
break;
|
|
8968
9474
|
case GGML_OP_RMS_NORM_BACK:
|
|
8969
9475
|
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -8977,6 +9483,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8977
9483
|
switch (ggml_get_unary_op(node)) {
|
|
8978
9484
|
case GGML_UNARY_OP_SILU:
|
|
8979
9485
|
case GGML_UNARY_OP_GELU:
|
|
9486
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
8980
9487
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
8981
9488
|
case GGML_UNARY_OP_RELU:
|
|
8982
9489
|
case GGML_UNARY_OP_TANH:
|
|
@@ -8987,6 +9494,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8987
9494
|
return false;
|
|
8988
9495
|
}
|
|
8989
9496
|
break;
|
|
9497
|
+
case GGML_OP_GLU:
|
|
9498
|
+
switch (ggml_get_glu_op(node)) {
|
|
9499
|
+
case GGML_GLU_OP_GEGLU:
|
|
9500
|
+
case GGML_GLU_OP_REGLU:
|
|
9501
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9502
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9503
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9504
|
+
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9505
|
+
break;
|
|
9506
|
+
default:
|
|
9507
|
+
return false;
|
|
9508
|
+
}
|
|
9509
|
+
break;
|
|
8990
9510
|
case GGML_OP_DIAG_MASK_INF:
|
|
8991
9511
|
ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
|
|
8992
9512
|
|
|
@@ -9042,6 +9562,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9042
9562
|
case GGML_OP_POOL_2D:
|
|
9043
9563
|
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
|
|
9044
9564
|
|
|
9565
|
+
break;
|
|
9566
|
+
case GGML_OP_CONV_2D:
|
|
9567
|
+
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
9568
|
+
|
|
9045
9569
|
break;
|
|
9046
9570
|
case GGML_OP_CONV_2D_DW:
|
|
9047
9571
|
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
@@ -9108,12 +9632,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9108
9632
|
|
|
9109
9633
|
ctx->compute_ctx.reset();
|
|
9110
9634
|
|
|
9111
|
-
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
|
|
9635
|
+
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
|
|
9112
9636
|
if (!ok) {
|
|
9113
9637
|
if (node->op == GGML_OP_UNARY) {
|
|
9114
9638
|
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
9115
|
-
}
|
|
9116
|
-
|
|
9639
|
+
} else if (node->op == GGML_OP_GLU) {
|
|
9640
|
+
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
|
9641
|
+
} else {
|
|
9117
9642
|
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
|
9118
9643
|
}
|
|
9119
9644
|
}
|
|
@@ -9122,7 +9647,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
9122
9647
|
return true;
|
|
9123
9648
|
}
|
|
9124
9649
|
|
|
9125
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9650
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
9651
|
+
GGML_UNUSED(cgraph);
|
|
9126
9652
|
ggml_backend_buffer * buf = nullptr;
|
|
9127
9653
|
|
|
9128
9654
|
switch (tensor->op) {
|
|
@@ -9140,7 +9666,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9140
9666
|
case GGML_OP_COS:
|
|
9141
9667
|
case GGML_OP_CLAMP:
|
|
9142
9668
|
case GGML_OP_PAD:
|
|
9669
|
+
case GGML_OP_ROLL:
|
|
9143
9670
|
case GGML_OP_CPY:
|
|
9671
|
+
case GGML_OP_SET_ROWS:
|
|
9144
9672
|
case GGML_OP_CONT:
|
|
9145
9673
|
case GGML_OP_DUP:
|
|
9146
9674
|
case GGML_OP_SILU_BACK:
|
|
@@ -9168,6 +9696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9168
9696
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
9169
9697
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
9170
9698
|
case GGML_OP_POOL_2D:
|
|
9699
|
+
case GGML_OP_CONV_2D:
|
|
9171
9700
|
case GGML_OP_CONV_2D_DW:
|
|
9172
9701
|
case GGML_OP_RWKV_WKV6:
|
|
9173
9702
|
case GGML_OP_RWKV_WKV7:
|
|
@@ -9182,6 +9711,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9182
9711
|
switch (ggml_get_unary_op(tensor)) {
|
|
9183
9712
|
case GGML_UNARY_OP_SILU:
|
|
9184
9713
|
case GGML_UNARY_OP_GELU:
|
|
9714
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9185
9715
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9186
9716
|
case GGML_UNARY_OP_RELU:
|
|
9187
9717
|
case GGML_UNARY_OP_TANH:
|
|
@@ -9192,6 +9722,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9192
9722
|
return false;
|
|
9193
9723
|
}
|
|
9194
9724
|
break;
|
|
9725
|
+
case GGML_OP_GLU:
|
|
9726
|
+
switch (ggml_get_glu_op(tensor)) {
|
|
9727
|
+
case GGML_GLU_OP_GEGLU:
|
|
9728
|
+
case GGML_GLU_OP_REGLU:
|
|
9729
|
+
case GGML_GLU_OP_SWIGLU:
|
|
9730
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
9731
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
9732
|
+
buf = tensor->buffer;
|
|
9733
|
+
break;
|
|
9734
|
+
default:
|
|
9735
|
+
return false;
|
|
9736
|
+
}
|
|
9737
|
+
break;
|
|
9195
9738
|
case GGML_OP_MUL_MAT:
|
|
9196
9739
|
case GGML_OP_MUL_MAT_ID:
|
|
9197
9740
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
@@ -9218,7 +9761,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9218
9761
|
// Only run if ctx hasn't been submitted yet
|
|
9219
9762
|
if (!subctx->seqs.empty()) {
|
|
9220
9763
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9221
|
-
ggml_vk_check_results_0(
|
|
9764
|
+
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
|
|
9222
9765
|
use_fence = true;
|
|
9223
9766
|
#endif
|
|
9224
9767
|
|
|
@@ -9238,7 +9781,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
9238
9781
|
ggml_vk_wait_for_fence(ctx);
|
|
9239
9782
|
}
|
|
9240
9783
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
9241
|
-
ggml_vk_check_results_1(
|
|
9784
|
+
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
|
|
9242
9785
|
#endif
|
|
9243
9786
|
}
|
|
9244
9787
|
|
|
@@ -9685,6 +10228,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
|
|
|
9685
10228
|
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
|
9686
10229
|
}
|
|
9687
10230
|
|
|
10231
|
+
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
|
10232
|
+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
|
10233
|
+
return false;
|
|
10234
|
+
}
|
|
10235
|
+
|
|
10236
|
+
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
|
10237
|
+
// additional constraints specific to this fusion
|
|
10238
|
+
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
|
10239
|
+
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
|
|
10240
|
+
|
|
10241
|
+
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
|
10242
|
+
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
|
10243
|
+
// rms_norm only supports f32
|
|
10244
|
+
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
|
10245
|
+
mul->src[1]->type != GGML_TYPE_F32 ||
|
|
10246
|
+
mul->type != GGML_TYPE_F32) {
|
|
10247
|
+
return false;
|
|
10248
|
+
}
|
|
10249
|
+
// if rms_norm is the B operand, then we don't handle broadcast
|
|
10250
|
+
if (rms_norm == mul->src[1] &&
|
|
10251
|
+
!ggml_are_same_shape(mul->src[0], rms_norm)) {
|
|
10252
|
+
return false;
|
|
10253
|
+
}
|
|
10254
|
+
// rms_norm shader assumes contiguous rows
|
|
10255
|
+
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
|
10256
|
+
return false;
|
|
10257
|
+
}
|
|
10258
|
+
}
|
|
10259
|
+
return true;
|
|
10260
|
+
}
|
|
10261
|
+
|
|
9688
10262
|
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
9689
10263
|
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
|
9690
10264
|
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
@@ -9698,10 +10272,21 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9698
10272
|
|
|
9699
10273
|
uint64_t total_mat_mul_bytes = 0;
|
|
9700
10274
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
9701
|
-
|
|
10275
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10276
|
+
ctx->num_additional_fused_ops = 1;
|
|
10277
|
+
}
|
|
10278
|
+
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
|
9702
10279
|
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
|
9703
10280
|
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
10281
|
+
} else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
|
|
10282
|
+
// Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode.
|
|
10283
|
+
auto CRS_size =
|
|
10284
|
+
cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
|
|
10285
|
+
auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
|
|
10286
|
+
total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
|
|
9704
10287
|
}
|
|
10288
|
+
i += ctx->num_additional_fused_ops;
|
|
10289
|
+
ctx->num_additional_fused_ops = 0;
|
|
9705
10290
|
}
|
|
9706
10291
|
if (ctx->device->need_compiles) {
|
|
9707
10292
|
ggml_vk_load_shaders(ctx->device);
|
|
@@ -9763,14 +10348,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9763
10348
|
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
9764
10349
|
}
|
|
9765
10350
|
|
|
10351
|
+
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
10352
|
+
ctx->num_additional_fused_ops = 1;
|
|
10353
|
+
}
|
|
10354
|
+
|
|
9766
10355
|
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
|
9767
10356
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
9768
10357
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
9769
10358
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
9770
|
-
(i == last_node) ||
|
|
10359
|
+
(i + ctx->num_additional_fused_ops == last_node) ||
|
|
9771
10360
|
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
9772
10361
|
|
|
9773
|
-
bool enqueued = ggml_vk_build_graph(ctx, cgraph
|
|
10362
|
+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
|
|
9774
10363
|
|
|
9775
10364
|
if (vk_perf_logger_enabled) {
|
|
9776
10365
|
if (ctx->compute_ctx.expired()) {
|
|
@@ -9780,7 +10369,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9780
10369
|
} else {
|
|
9781
10370
|
compute_ctx = ctx->compute_ctx.lock();
|
|
9782
10371
|
}
|
|
9783
|
-
|
|
10372
|
+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
|
|
10373
|
+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
|
|
10374
|
+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
|
|
10375
|
+
}
|
|
9784
10376
|
}
|
|
9785
10377
|
|
|
9786
10378
|
if (enqueued) {
|
|
@@ -9802,6 +10394,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9802
10394
|
}
|
|
9803
10395
|
submit_count++;
|
|
9804
10396
|
}
|
|
10397
|
+
i += ctx->num_additional_fused_ops;
|
|
10398
|
+
ctx->num_additional_fused_ops = 0;
|
|
9805
10399
|
}
|
|
9806
10400
|
|
|
9807
10401
|
if (vk_perf_logger_enabled) {
|
|
@@ -9963,6 +10557,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9963
10557
|
case GGML_OP_UNARY:
|
|
9964
10558
|
switch (ggml_get_unary_op(op)) {
|
|
9965
10559
|
case GGML_UNARY_OP_GELU:
|
|
10560
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
9966
10561
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
9967
10562
|
case GGML_UNARY_OP_SILU:
|
|
9968
10563
|
case GGML_UNARY_OP_RELU:
|
|
@@ -9976,15 +10571,32 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9976
10571
|
return false;
|
|
9977
10572
|
}
|
|
9978
10573
|
break;
|
|
10574
|
+
case GGML_OP_GLU:
|
|
10575
|
+
switch (ggml_get_glu_op(op)) {
|
|
10576
|
+
case GGML_GLU_OP_GEGLU:
|
|
10577
|
+
case GGML_GLU_OP_REGLU:
|
|
10578
|
+
case GGML_GLU_OP_SWIGLU:
|
|
10579
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
10580
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
10581
|
+
return ggml_is_contiguous(op->src[0]) &&
|
|
10582
|
+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
10583
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
10584
|
+
(op->src[0]->type == op->type);
|
|
10585
|
+
default:
|
|
10586
|
+
return false;
|
|
10587
|
+
}
|
|
10588
|
+
break;
|
|
9979
10589
|
case GGML_OP_MUL_MAT:
|
|
9980
10590
|
case GGML_OP_MUL_MAT_ID:
|
|
9981
10591
|
{
|
|
9982
10592
|
ggml_type src0_type = op->src[0]->type;
|
|
9983
10593
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
9984
10594
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
9985
|
-
if (op->op == GGML_OP_MUL_MAT_ID
|
|
9986
|
-
|
|
9987
|
-
|
|
10595
|
+
if (op->op == GGML_OP_MUL_MAT_ID) {
|
|
10596
|
+
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
|
10597
|
+
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
|
10598
|
+
return false;
|
|
10599
|
+
}
|
|
9988
10600
|
}
|
|
9989
10601
|
switch (src0_type) {
|
|
9990
10602
|
case GGML_TYPE_F32:
|
|
@@ -10042,19 +10654,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10042
10654
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
10043
10655
|
auto device = ggml_vk_get_device(ctx->device);
|
|
10044
10656
|
bool coopmat2 = device->coopmat2;
|
|
10045
|
-
|
|
10046
|
-
|
|
10047
|
-
case 80:
|
|
10048
|
-
case 96:
|
|
10049
|
-
case 112:
|
|
10050
|
-
case 128:
|
|
10051
|
-
case 256:
|
|
10052
|
-
break;
|
|
10053
|
-
default:
|
|
10054
|
-
return false;
|
|
10055
|
-
}
|
|
10056
|
-
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
|
10057
|
-
// different head sizes of K and V are not supported yet
|
|
10657
|
+
FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
|
|
10658
|
+
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
|
|
10058
10659
|
return false;
|
|
10059
10660
|
}
|
|
10060
10661
|
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
@@ -10134,6 +10735,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10134
10735
|
return false;
|
|
10135
10736
|
}
|
|
10136
10737
|
} break;
|
|
10738
|
+
case GGML_OP_SET_ROWS:
|
|
10739
|
+
{
|
|
10740
|
+
switch (op->type) {
|
|
10741
|
+
case GGML_TYPE_F32:
|
|
10742
|
+
case GGML_TYPE_F16:
|
|
10743
|
+
case GGML_TYPE_BF16:
|
|
10744
|
+
case GGML_TYPE_Q4_0:
|
|
10745
|
+
case GGML_TYPE_Q4_1:
|
|
10746
|
+
case GGML_TYPE_Q5_0:
|
|
10747
|
+
case GGML_TYPE_Q5_1:
|
|
10748
|
+
case GGML_TYPE_Q8_0:
|
|
10749
|
+
case GGML_TYPE_IQ4_NL:
|
|
10750
|
+
return true;
|
|
10751
|
+
default:
|
|
10752
|
+
return false;
|
|
10753
|
+
}
|
|
10754
|
+
} break;
|
|
10137
10755
|
case GGML_OP_CONT:
|
|
10138
10756
|
case GGML_OP_CPY:
|
|
10139
10757
|
case GGML_OP_DUP:
|
|
@@ -10218,11 +10836,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10218
10836
|
case GGML_OP_CLAMP:
|
|
10219
10837
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
10220
10838
|
case GGML_OP_UPSCALE:
|
|
10221
|
-
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
10222
10839
|
case GGML_OP_ACC:
|
|
10223
10840
|
case GGML_OP_CONCAT:
|
|
10224
10841
|
case GGML_OP_SCALE:
|
|
10225
10842
|
case GGML_OP_PAD:
|
|
10843
|
+
case GGML_OP_ROLL:
|
|
10226
10844
|
case GGML_OP_DIAG_MASK_INF:
|
|
10227
10845
|
case GGML_OP_SOFT_MAX:
|
|
10228
10846
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -10242,6 +10860,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
10242
10860
|
return true;
|
|
10243
10861
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
10244
10862
|
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
|
10863
|
+
case GGML_OP_CONV_2D:
|
|
10864
|
+
{
|
|
10865
|
+
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
|
|
10866
|
+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
10867
|
+
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
10868
|
+
bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
|
|
10869
|
+
// Channel-contiguous format is not supported yet.
|
|
10870
|
+
return (op->src[0]->type == GGML_TYPE_F32 &&
|
|
10871
|
+
op->src[1]->type == GGML_TYPE_F32 &&
|
|
10872
|
+
op->type == GGML_TYPE_F32 &&
|
|
10873
|
+
ggml_is_contiguous(op->src[0]) &&
|
|
10874
|
+
ggml_is_contiguous(op->src[1]) &&
|
|
10875
|
+
ggml_is_contiguous(op)) && !is_Apple;
|
|
10876
|
+
}
|
|
10245
10877
|
default:
|
|
10246
10878
|
return false;
|
|
10247
10879
|
}
|
|
@@ -10513,11 +11145,21 @@ void * comp_result;
|
|
|
10513
11145
|
size_t comp_size;
|
|
10514
11146
|
size_t comp_nb[GGML_MAX_DIMS];
|
|
10515
11147
|
size_t check_counter = 0;
|
|
10516
|
-
static void ggml_vk_check_results_0(
|
|
11148
|
+
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
11149
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10517
11150
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10518
11151
|
return;
|
|
10519
11152
|
}
|
|
10520
11153
|
|
|
11154
|
+
bool fused_rms_norm_mul = false;
|
|
11155
|
+
int rms_norm_idx = -1;
|
|
11156
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
11157
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
11158
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
11159
|
+
fused_rms_norm_mul = true;
|
|
11160
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
11161
|
+
}
|
|
11162
|
+
|
|
10521
11163
|
check_counter++;
|
|
10522
11164
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10523
11165
|
return;
|
|
@@ -10545,6 +11187,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10545
11187
|
|
|
10546
11188
|
for (int i = 0; i < 6; i++) {
|
|
10547
11189
|
ggml_tensor * srci = tensor->src[i];
|
|
11190
|
+
if (fused_rms_norm_mul) {
|
|
11191
|
+
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
|
11192
|
+
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
|
|
11193
|
+
switch (i) {
|
|
11194
|
+
case 0: srci = rms_norm->src[0]; break;
|
|
11195
|
+
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
|
|
11196
|
+
default: continue;
|
|
11197
|
+
}
|
|
11198
|
+
}
|
|
10548
11199
|
if (srci == nullptr) {
|
|
10549
11200
|
continue;
|
|
10550
11201
|
}
|
|
@@ -10602,7 +11253,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10602
11253
|
} else if (tensor->op == GGML_OP_SUB) {
|
|
10603
11254
|
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10604
11255
|
} else if (tensor->op == GGML_OP_MUL) {
|
|
10605
|
-
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
11256
|
+
if (fused_rms_norm_mul) {
|
|
11257
|
+
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
|
|
11258
|
+
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
|
|
11259
|
+
} else {
|
|
11260
|
+
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
|
|
11261
|
+
}
|
|
10606
11262
|
} else if (tensor->op == GGML_OP_DIV) {
|
|
10607
11263
|
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10608
11264
|
} else if (tensor->op == GGML_OP_CONCAT) {
|
|
@@ -10690,6 +11346,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10690
11346
|
case GGML_UNARY_OP_GELU:
|
|
10691
11347
|
tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
|
|
10692
11348
|
break;
|
|
11349
|
+
case GGML_UNARY_OP_GELU_ERF:
|
|
11350
|
+
tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
|
|
11351
|
+
break;
|
|
10693
11352
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
10694
11353
|
tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
|
|
10695
11354
|
break;
|
|
@@ -10706,6 +11365,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10706
11365
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
10707
11366
|
GGML_ABORT("fatal error");
|
|
10708
11367
|
}
|
|
11368
|
+
} else if (tensor->op == GGML_OP_GLU) {
|
|
11369
|
+
if (src_clone[1] == nullptr) {
|
|
11370
|
+
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
|
|
11371
|
+
} else {
|
|
11372
|
+
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
|
11373
|
+
}
|
|
10709
11374
|
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
|
10710
11375
|
if (src1 == nullptr) {
|
|
10711
11376
|
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
|
@@ -10713,6 +11378,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10713
11378
|
} else {
|
|
10714
11379
|
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10715
11380
|
}
|
|
11381
|
+
} else if (tensor->op == GGML_OP_SET_ROWS) {
|
|
11382
|
+
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
|
10716
11383
|
} else if (tensor->op == GGML_OP_CONT) {
|
|
10717
11384
|
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
10718
11385
|
} else if (tensor->op == GGML_OP_RESHAPE) {
|
|
@@ -10765,6 +11432,14 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10765
11432
|
const int32_t p1 = tensor->op_params[6];
|
|
10766
11433
|
|
|
10767
11434
|
tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
|
|
11435
|
+
} else if (tensor->op == GGML_OP_CONV_2D) {
|
|
11436
|
+
const int32_t s0 = tensor->op_params[0];
|
|
11437
|
+
const int32_t s1 = tensor->op_params[1];
|
|
11438
|
+
const int32_t p0 = tensor->op_params[2];
|
|
11439
|
+
const int32_t p1 = tensor->op_params[3];
|
|
11440
|
+
const int32_t d0 = tensor->op_params[4];
|
|
11441
|
+
const int32_t d1 = tensor->op_params[5];
|
|
11442
|
+
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
|
|
10768
11443
|
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
|
|
10769
11444
|
const float * op_params = (const float *)tensor->op_params;
|
|
10770
11445
|
tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
|
|
@@ -10784,10 +11459,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10784
11459
|
GGML_ABORT("fatal error");
|
|
10785
11460
|
}
|
|
10786
11461
|
|
|
10787
|
-
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
|
|
10788
|
-
ggml_build_forward_expand(cgraph, tensor_clone);
|
|
11462
|
+
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
|
|
11463
|
+
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
|
|
10789
11464
|
|
|
10790
|
-
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
|
|
11465
|
+
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
|
|
10791
11466
|
|
|
10792
11467
|
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
|
10793
11468
|
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
|
|
@@ -10810,10 +11485,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10810
11485
|
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
|
|
10811
11486
|
}
|
|
10812
11487
|
|
|
10813
|
-
static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
11488
|
+
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
|
|
11489
|
+
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
|
|
10814
11490
|
if (tensor->op == GGML_OP_TRANSPOSE) {
|
|
10815
11491
|
return;
|
|
10816
11492
|
}
|
|
11493
|
+
bool fused_rms_norm_mul = false;
|
|
11494
|
+
if (ctx->num_additional_fused_ops == 1 &&
|
|
11495
|
+
tensor->op == GGML_OP_RMS_NORM &&
|
|
11496
|
+
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
|
|
11497
|
+
fused_rms_norm_mul = true;
|
|
11498
|
+
tensor = cgraph->nodes[tensor_idx + 1];
|
|
11499
|
+
}
|
|
11500
|
+
|
|
10817
11501
|
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
|
|
10818
11502
|
return;
|
|
10819
11503
|
}
|