@novastera-oss/llamarn 0.2.1 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +80 -14
- package/RNLlamaCpp.podspec +10 -3
- package/android/CMakeLists.txt +8 -0
- package/android/src/main/cpp/include/llama.h +62 -125
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/PureCppImpl.cpp +9 -27
- package/cpp/SystemUtils.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +11 -3
- package/cpp/llama.cpp/build-xcframework.sh +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/common/arg.cpp +153 -113
- package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
- package/cpp/llama.cpp/common/chat-parser.h +117 -0
- package/cpp/llama.cpp/common/chat.cpp +847 -699
- package/cpp/llama.cpp/common/chat.h +73 -6
- package/cpp/llama.cpp/common/common.cpp +50 -82
- package/cpp/llama.cpp/common/common.h +21 -17
- package/cpp/llama.cpp/common/json-partial.cpp +255 -0
- package/cpp/llama.cpp/common/json-partial.h +37 -0
- package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
- package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
- package/cpp/llama.cpp/common/regex-partial.h +56 -0
- package/cpp/llama.cpp/common/sampling.cpp +7 -8
- package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
- package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
- package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
- package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
- package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
- package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
- package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
- package/cpp/llama.cpp/include/llama.h +62 -125
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
- package/cpp/llama.cpp/src/llama-arch.h +2 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
- package/cpp/llama.cpp/src/llama-context.cpp +340 -123
- package/cpp/llama.cpp/src/llama-context.h +30 -0
- package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-cparams.h +2 -0
- package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
- package/cpp/llama.cpp/src/llama-graph.h +52 -7
- package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
- package/cpp/llama.cpp/src/llama-hparams.h +37 -5
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
- package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
- package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
- package/cpp/llama.cpp/src/llama-memory.h +4 -3
- package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
- package/cpp/llama.cpp/src/llama-model.cpp +529 -172
- package/cpp/llama.cpp/src/llama-model.h +6 -1
- package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
- package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
- package/cpp/llama.cpp/src/llama-vocab.h +6 -0
- package/cpp/llama.cpp/src/llama.cpp +14 -0
- package/cpp/rn-completion.cpp +60 -5
- package/ios/include/chat.h +73 -6
- package/ios/include/common/minja/chat-template.hpp +9 -5
- package/ios/include/common/minja/minja.hpp +69 -36
- package/ios/include/common.h +21 -17
- package/ios/include/llama.h +62 -125
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/common/stb_image.h +0 -7988
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
#include "ggml-vulkan.h"
|
|
2
2
|
#include <vulkan/vulkan_core.h>
|
|
3
|
-
#if defined(GGML_VULKAN_RUN_TESTS) || defined(
|
|
3
|
+
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
|
|
4
4
|
#include <chrono>
|
|
5
5
|
#include "ggml-cpu.h"
|
|
6
6
|
#endif
|
|
@@ -184,9 +184,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
|
|
|
184
184
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
185
185
|
class vk_memory_logger;
|
|
186
186
|
#endif
|
|
187
|
-
#ifdef GGML_VULKAN_PERF
|
|
188
187
|
class vk_perf_logger;
|
|
189
|
-
#endif
|
|
190
188
|
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
|
191
189
|
|
|
192
190
|
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
|
@@ -288,6 +286,9 @@ struct vk_device_struct {
|
|
|
288
286
|
bool coopmat_acc_f32_support {};
|
|
289
287
|
bool coopmat_acc_f16_support {};
|
|
290
288
|
bool coopmat_bf16_support {};
|
|
289
|
+
bool coopmat_support_16x16x16_f16acc {};
|
|
290
|
+
bool coopmat_support_16x16x16_f32acc {};
|
|
291
|
+
bool coopmat1_fa_support {};
|
|
291
292
|
uint32_t coopmat_m;
|
|
292
293
|
uint32_t coopmat_n;
|
|
293
294
|
uint32_t coopmat_k;
|
|
@@ -410,6 +411,13 @@ struct vk_device_struct {
|
|
|
410
411
|
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
411
412
|
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
412
413
|
|
|
414
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
415
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
416
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
417
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
418
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
419
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
420
|
+
|
|
413
421
|
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
414
422
|
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
415
423
|
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
@@ -432,9 +440,11 @@ struct vk_device_struct {
|
|
|
432
440
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
433
441
|
std::unique_ptr<vk_memory_logger> memory_logger;
|
|
434
442
|
#endif
|
|
435
|
-
|
|
443
|
+
|
|
444
|
+
// for GGML_VK_PERF_LOGGER
|
|
436
445
|
std::unique_ptr<vk_perf_logger> perf_logger;
|
|
437
|
-
|
|
446
|
+
vk::QueryPool query_pool;
|
|
447
|
+
uint32_t num_queries;
|
|
438
448
|
|
|
439
449
|
~vk_device_struct() {
|
|
440
450
|
VK_LOG_DEBUG("destroy device " << name);
|
|
@@ -818,8 +828,6 @@ private:
|
|
|
818
828
|
#define VK_LOG_MEMORY(msg) ((void) 0)
|
|
819
829
|
#endif // GGML_VULKAN_MEMORY_DEBUG
|
|
820
830
|
|
|
821
|
-
#if defined(GGML_VULKAN_PERF)
|
|
822
|
-
|
|
823
831
|
class vk_perf_logger {
|
|
824
832
|
public:
|
|
825
833
|
void print_timings() {
|
|
@@ -829,7 +837,7 @@ public:
|
|
|
829
837
|
for (const auto& time : t.second) {
|
|
830
838
|
total += time;
|
|
831
839
|
}
|
|
832
|
-
std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << "
|
|
840
|
+
std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
|
|
833
841
|
}
|
|
834
842
|
|
|
835
843
|
timings.clear();
|
|
@@ -858,7 +866,6 @@ public:
|
|
|
858
866
|
private:
|
|
859
867
|
std::map<std::string, std::vector<uint64_t>> timings;
|
|
860
868
|
};
|
|
861
|
-
#endif // GGML_VULKAN_PERF
|
|
862
869
|
|
|
863
870
|
struct ggml_backend_vk_context {
|
|
864
871
|
std::string name;
|
|
@@ -948,6 +955,8 @@ struct vk_instance_t {
|
|
|
948
955
|
static bool vk_instance_initialized = false;
|
|
949
956
|
static vk_instance_t vk_instance;
|
|
950
957
|
|
|
958
|
+
static bool vk_perf_logger_enabled = false;
|
|
959
|
+
|
|
951
960
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
952
961
|
static size_t vk_skip_checks;
|
|
953
962
|
static size_t vk_output_tensor;
|
|
@@ -1588,19 +1597,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
1588
1597
|
);
|
|
1589
1598
|
}
|
|
1590
1599
|
|
|
1600
|
+
enum FaCodePath {
|
|
1601
|
+
FA_SCALAR,
|
|
1602
|
+
FA_COOPMAT1,
|
|
1603
|
+
FA_COOPMAT2,
|
|
1604
|
+
};
|
|
1605
|
+
|
|
1591
1606
|
// number of rows/cols for flash attention shader
|
|
1592
1607
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1593
1608
|
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
1594
1609
|
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
|
1595
1610
|
|
|
1596
|
-
|
|
1597
|
-
|
|
1611
|
+
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
1612
|
+
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
1613
|
+
// of the Bc dimension.
|
|
1614
|
+
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
|
1615
|
+
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
|
1616
|
+
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
|
1617
|
+
|
|
1618
|
+
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|
1619
|
+
if (path == FA_COOPMAT2) {
|
|
1620
|
+
return flash_attention_num_small_rows;
|
|
1621
|
+
} else {
|
|
1622
|
+
return scalar_flash_attention_num_small_rows;
|
|
1623
|
+
}
|
|
1598
1624
|
}
|
|
1599
1625
|
|
|
1600
|
-
static std::array<uint32_t, 2> fa_rows_cols(
|
|
1626
|
+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1601
1627
|
GGML_UNUSED(clamp);
|
|
1602
1628
|
|
|
1603
|
-
if (
|
|
1629
|
+
if (path == FA_SCALAR) {
|
|
1604
1630
|
if (small_rows) {
|
|
1605
1631
|
return {scalar_flash_attention_num_small_rows, 64};
|
|
1606
1632
|
} else {
|
|
@@ -1608,9 +1634,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
|
|
|
1608
1634
|
}
|
|
1609
1635
|
}
|
|
1610
1636
|
|
|
1637
|
+
if (path == FA_COOPMAT1) {
|
|
1638
|
+
if (small_rows) {
|
|
1639
|
+
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
|
1640
|
+
} else {
|
|
1641
|
+
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
|
1642
|
+
}
|
|
1643
|
+
}
|
|
1644
|
+
|
|
1611
1645
|
// small rows, large cols
|
|
1612
1646
|
if (small_rows) {
|
|
1613
|
-
return {get_fa_num_small_rows(
|
|
1647
|
+
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
|
1614
1648
|
}
|
|
1615
1649
|
|
|
1616
1650
|
// small cols to reduce register count
|
|
@@ -1907,17 +1941,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1907
1941
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
1908
1942
|
};
|
|
1909
1943
|
|
|
1910
|
-
auto const &fa_wg_denoms = [&](
|
|
1911
|
-
return {fa_rows_cols(
|
|
1944
|
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
1945
|
+
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
|
1912
1946
|
};
|
|
1913
1947
|
|
|
1914
|
-
auto const &fa_spec_constants = [&](
|
|
1948
|
+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
1915
1949
|
// For large number of rows, 128 invocations seems to work best.
|
|
1916
1950
|
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
1917
1951
|
// can't use 256 for D==80.
|
|
1918
1952
|
// For scalar, use 128 (arbitrary)
|
|
1919
|
-
uint32_t wg_size =
|
|
1920
|
-
|
|
1953
|
+
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
1954
|
+
? scalar_flash_attention_workgroup_size
|
|
1955
|
+
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
1956
|
+
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
|
1921
1957
|
|
|
1922
1958
|
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
1923
1959
|
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
@@ -1929,36 +1965,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1929
1965
|
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
|
1930
1966
|
};
|
|
1931
1967
|
|
|
1932
|
-
#define CREATE_FA2(TYPE, NAMELC,
|
|
1933
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1934
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1935
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1936
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1937
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1938
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1939
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1940
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(
|
|
1941
|
-
|
|
1942
|
-
#define CREATE_FA(TYPE, NAMELC,
|
|
1943
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1944
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1945
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1946
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1947
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1948
|
-
CREATE_FA2(TYPE, NAMELC,
|
|
1949
|
-
|
|
1950
|
-
CREATE_FA(GGML_TYPE_F16, f16,
|
|
1951
|
-
CREATE_FA(GGML_TYPE_Q4_0, q4_0,
|
|
1952
|
-
CREATE_FA(GGML_TYPE_Q8_0, q8_0,
|
|
1968
|
+
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
|
1969
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1970
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1971
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1972
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1973
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1974
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1975
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1976
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1977
|
+
|
|
1978
|
+
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
1979
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
|
1980
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
|
1981
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
|
1982
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
|
1983
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
|
1984
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
|
1985
|
+
|
|
1986
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
1987
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
1988
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
|
1989
|
+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
1990
|
+
if (device->coopmat1_fa_support) {
|
|
1991
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
|
1992
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
|
1993
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
|
1994
|
+
}
|
|
1995
|
+
#endif
|
|
1953
1996
|
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
1954
1997
|
if (device->coopmat2) {
|
|
1955
|
-
CREATE_FA(GGML_TYPE_F16, f16,
|
|
1956
|
-
CREATE_FA(GGML_TYPE_Q4_0, q4_0,
|
|
1957
|
-
CREATE_FA(GGML_TYPE_Q4_1, q4_1,
|
|
1958
|
-
CREATE_FA(GGML_TYPE_Q5_0, q5_0,
|
|
1959
|
-
CREATE_FA(GGML_TYPE_Q5_1, q5_1,
|
|
1960
|
-
CREATE_FA(GGML_TYPE_Q8_0, q8_0,
|
|
1961
|
-
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl,
|
|
1998
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
|
1999
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
|
2000
|
+
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
|
2001
|
+
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
|
2002
|
+
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
|
2003
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
|
2004
|
+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
|
1962
2005
|
}
|
|
1963
2006
|
#endif
|
|
1964
2007
|
#undef CREATE_FA2
|
|
@@ -1987,25 +2030,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1987
2030
|
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
1988
2031
|
}
|
|
1989
2032
|
#endif
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
|
|
1998
|
-
|
|
1999
|
-
|
|
2000
|
-
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2007
|
-
|
|
2008
|
-
|
|
2033
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2034
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2035
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2036
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2037
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2038
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
2039
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
2040
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
2041
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
2042
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
2043
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2044
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2045
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2046
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2047
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2048
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2049
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2050
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2051
|
+
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
2009
2052
|
|
|
2010
2053
|
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
2011
2054
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
@@ -2041,17 +2084,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2041
2084
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
2042
2085
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2043
2086
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2044
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ##
|
|
2087
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
|
2045
2088
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
2046
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ##
|
|
2089
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
|
2047
2090
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2048
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ##
|
|
2091
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
|
2049
2092
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2050
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ##
|
|
2093
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
|
2051
2094
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
2052
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ##
|
|
2095
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
|
2053
2096
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2054
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ##
|
|
2097
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
|
2055
2098
|
|
|
2056
2099
|
// Create 2 variants, {f16,f32} accumulator
|
|
2057
2100
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
@@ -2073,47 +2116,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2073
2116
|
#endif
|
|
2074
2117
|
|
|
2075
2118
|
if (device->coopmat_acc_f16_support) {
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
|
|
2090
|
-
|
|
2091
|
-
|
|
2092
|
-
|
|
2093
|
-
|
|
2094
|
-
|
|
2095
|
-
|
|
2119
|
+
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2120
|
+
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2121
|
+
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2122
|
+
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2123
|
+
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2124
|
+
|
|
2125
|
+
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2126
|
+
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2127
|
+
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2128
|
+
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2129
|
+
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2130
|
+
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2131
|
+
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2132
|
+
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2133
|
+
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2134
|
+
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2135
|
+
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2136
|
+
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2137
|
+
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2138
|
+
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2096
2139
|
} else {
|
|
2097
|
-
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].
|
|
2098
|
-
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].
|
|
2099
|
-
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].
|
|
2100
|
-
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].
|
|
2101
|
-
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].
|
|
2102
|
-
|
|
2103
|
-
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].
|
|
2104
|
-
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].
|
|
2105
|
-
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].
|
|
2106
|
-
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].
|
|
2107
|
-
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].
|
|
2108
|
-
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].
|
|
2109
|
-
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].
|
|
2110
|
-
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].
|
|
2111
|
-
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].
|
|
2112
|
-
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].
|
|
2113
|
-
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].
|
|
2114
|
-
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].
|
|
2115
|
-
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].
|
|
2116
|
-
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].
|
|
2140
|
+
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2141
|
+
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2142
|
+
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2143
|
+
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2144
|
+
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2145
|
+
|
|
2146
|
+
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2147
|
+
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2148
|
+
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2149
|
+
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2150
|
+
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2151
|
+
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2152
|
+
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2153
|
+
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2154
|
+
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2155
|
+
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2156
|
+
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2157
|
+
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2158
|
+
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2159
|
+
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2117
2160
|
}
|
|
2118
2161
|
|
|
2119
2162
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
@@ -2188,13 +2231,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2188
2231
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2189
2232
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
2190
2233
|
|
|
2191
|
-
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC,
|
|
2192
|
-
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2193
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC
|
|
2194
|
-
|
|
2195
|
-
|
|
2196
|
-
if (device->mul_mat ## ID ##
|
|
2197
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->
|
|
2234
|
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2235
|
+
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
|
2236
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
2237
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
2238
|
+
} \
|
|
2239
|
+
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
|
2240
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
2241
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
2242
|
+
} \
|
|
2243
|
+
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
|
2244
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
2245
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
2246
|
+
} \
|
|
2198
2247
|
|
|
2199
2248
|
// Create 2 variants, {f16,f32} accumulator
|
|
2200
2249
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
@@ -2208,34 +2257,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2208
2257
|
|
|
2209
2258
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2210
2259
|
|
|
2211
|
-
|
|
2212
|
-
|
|
2213
|
-
|
|
2214
|
-
|
|
2215
|
-
|
|
2216
|
-
|
|
2217
|
-
|
|
2218
|
-
|
|
2219
|
-
|
|
2220
|
-
|
|
2221
|
-
|
|
2222
|
-
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
|
|
2226
|
-
|
|
2227
|
-
|
|
2228
|
-
|
|
2229
|
-
|
|
2230
|
-
|
|
2260
|
+
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2261
|
+
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2262
|
+
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2263
|
+
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2264
|
+
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2265
|
+
|
|
2266
|
+
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2267
|
+
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2268
|
+
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2269
|
+
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2270
|
+
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2271
|
+
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2272
|
+
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2273
|
+
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2274
|
+
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2275
|
+
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2276
|
+
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2277
|
+
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2278
|
+
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2279
|
+
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2231
2280
|
|
|
2232
2281
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2233
2282
|
if (device->integer_dot_product) {
|
|
2234
|
-
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0]
|
|
2235
|
-
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1]
|
|
2236
|
-
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0]
|
|
2237
|
-
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1]
|
|
2238
|
-
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0]
|
|
2283
|
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2284
|
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2285
|
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2286
|
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2287
|
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2239
2288
|
}
|
|
2240
2289
|
#endif
|
|
2241
2290
|
|
|
@@ -2284,13 +2333,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2284
2333
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2285
2334
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
2286
2335
|
|
|
2287
|
-
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC,
|
|
2336
|
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2288
2337
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2289
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC
|
|
2338
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
2290
2339
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
2291
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC
|
|
2340
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
2292
2341
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2293
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC
|
|
2342
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
2294
2343
|
|
|
2295
2344
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2296
2345
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
@@ -2322,11 +2371,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2322
2371
|
|
|
2323
2372
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2324
2373
|
if (device->integer_dot_product) {
|
|
2325
|
-
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1,
|
|
2326
|
-
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1,
|
|
2327
|
-
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1,
|
|
2328
|
-
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1,
|
|
2329
|
-
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1,
|
|
2374
|
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2375
|
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2376
|
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2377
|
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2378
|
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2330
2379
|
}
|
|
2331
2380
|
#endif
|
|
2332
2381
|
|
|
@@ -2707,9 +2756,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2707
2756
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
2708
2757
|
device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
|
|
2709
2758
|
#endif
|
|
2710
|
-
|
|
2711
|
-
|
|
2712
|
-
|
|
2759
|
+
if (vk_perf_logger_enabled) {
|
|
2760
|
+
device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
|
2761
|
+
}
|
|
2713
2762
|
|
|
2714
2763
|
size_t dev_num = vk_instance.device_indices[idx];
|
|
2715
2764
|
|
|
@@ -2754,23 +2803,29 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2754
2803
|
pipeline_robustness = true;
|
|
2755
2804
|
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
|
2756
2805
|
device->subgroup_size_control = true;
|
|
2806
|
+
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
2757
2807
|
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
|
2758
2808
|
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
|
2759
2809
|
device->coopmat_support = true;
|
|
2760
2810
|
device->coopmat_m = 0;
|
|
2761
2811
|
device->coopmat_n = 0;
|
|
2762
2812
|
device->coopmat_k = 0;
|
|
2813
|
+
#endif
|
|
2814
|
+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
2763
2815
|
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
|
2764
2816
|
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
2765
2817
|
coopmat2_support = true;
|
|
2818
|
+
#endif
|
|
2766
2819
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2767
2820
|
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
2768
2821
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
2769
2822
|
device->integer_dot_product = true;
|
|
2770
2823
|
#endif
|
|
2824
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
2771
2825
|
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
|
2772
2826
|
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
2773
2827
|
bfloat16_support = true;
|
|
2828
|
+
#endif
|
|
2774
2829
|
}
|
|
2775
2830
|
}
|
|
2776
2831
|
|
|
@@ -3009,6 +3064,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3009
3064
|
|
|
3010
3065
|
#if defined(VK_KHR_cooperative_matrix)
|
|
3011
3066
|
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
|
3067
|
+
|
|
3068
|
+
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
|
3069
|
+
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
|
3070
|
+
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
|
3071
|
+
device->subgroup_max_size >= 32;
|
|
3012
3072
|
#endif
|
|
3013
3073
|
|
|
3014
3074
|
if (coopmat2_support) {
|
|
@@ -3143,6 +3203,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3143
3203
|
// Only enable if shape is identical
|
|
3144
3204
|
device->coopmat_acc_f32_support = true;
|
|
3145
3205
|
}
|
|
3206
|
+
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
|
3207
|
+
device->coopmat_support_16x16x16_f32acc = true;
|
|
3208
|
+
}
|
|
3146
3209
|
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
|
3147
3210
|
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
|
3148
3211
|
// coopmat sizes not set yet
|
|
@@ -3155,6 +3218,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
3155
3218
|
// Only enable if shape is identical
|
|
3156
3219
|
device->coopmat_acc_f16_support = true;
|
|
3157
3220
|
}
|
|
3221
|
+
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
|
3222
|
+
device->coopmat_support_16x16x16_f16acc = true;
|
|
3223
|
+
}
|
|
3158
3224
|
}
|
|
3159
3225
|
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
|
3160
3226
|
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
|
@@ -3480,6 +3546,8 @@ static void ggml_vk_instance_init() {
|
|
|
3480
3546
|
vk_instance.instance = vk::createInstance(instance_create_info);
|
|
3481
3547
|
vk_instance_initialized = true;
|
|
3482
3548
|
|
|
3549
|
+
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
|
3550
|
+
|
|
3483
3551
|
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
|
3484
3552
|
|
|
3485
3553
|
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
|
@@ -3656,7 +3724,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
|
|
3656
3724
|
}
|
|
3657
3725
|
|
|
3658
3726
|
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
|
3659
|
-
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
3727
|
+
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
|
|
3660
3728
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
|
3661
3729
|
return ctx->device->pipeline_matmul_f32;
|
|
3662
3730
|
}
|
|
@@ -3684,7 +3752,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
3684
3752
|
|
|
3685
3753
|
// MMQ
|
|
3686
3754
|
if (src1_type == GGML_TYPE_Q8_1) {
|
|
3687
|
-
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
|
|
3755
|
+
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
|
3688
3756
|
|
|
3689
3757
|
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
|
3690
3758
|
return nullptr;
|
|
@@ -3724,9 +3792,12 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
3724
3792
|
|
|
3725
3793
|
if (ctx->device->coopmat2) {
|
|
3726
3794
|
assert(src1_type == GGML_TYPE_F16);
|
|
3727
|
-
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
|
3795
|
+
return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
|
|
3728
3796
|
}
|
|
3729
|
-
|
|
3797
|
+
if (ctx->device->coopmat_support) {
|
|
3798
|
+
return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
|
3799
|
+
}
|
|
3800
|
+
return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
|
3730
3801
|
}
|
|
3731
3802
|
|
|
3732
3803
|
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
|
|
@@ -4449,6 +4520,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
|
|
|
4449
4520
|
return aligned ? mmp->a_m : mmp->m;
|
|
4450
4521
|
}
|
|
4451
4522
|
return aligned ? mmp->a_l : mmp->l;
|
|
4523
|
+
|
|
4524
|
+
GGML_UNUSED(src1_type);
|
|
4452
4525
|
}
|
|
4453
4526
|
|
|
4454
4527
|
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
|
@@ -4604,6 +4677,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
4604
4677
|
}
|
|
4605
4678
|
}
|
|
4606
4679
|
|
|
4680
|
+
if (src->type == to) {
|
|
4681
|
+
// Copy two or four bytes at a time, depending on block size.
|
|
4682
|
+
// For quantized types, we scale by block size/type size. But
|
|
4683
|
+
// this path is also used for bf16->bf16 for example, where the
|
|
4684
|
+
// type size must be exactly 2 or 4.
|
|
4685
|
+
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
|
|
4686
|
+
if ((ggml_type_size(src->type) % 4) == 0) {
|
|
4687
|
+
return ctx->device->pipeline_contig_cpy_f32_f32;
|
|
4688
|
+
} else {
|
|
4689
|
+
return ctx->device->pipeline_contig_cpy_f16_f16;
|
|
4690
|
+
}
|
|
4691
|
+
}
|
|
4692
|
+
|
|
4607
4693
|
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
|
|
4608
4694
|
GGML_ABORT("fatal error");
|
|
4609
4695
|
}
|
|
@@ -5688,6 +5774,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5688
5774
|
}
|
|
5689
5775
|
}
|
|
5690
5776
|
|
|
5777
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
|
5778
|
+
// Needs to be kept up to date on shader changes
|
|
5779
|
+
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
5780
|
+
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
|
5781
|
+
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
5782
|
+
|
|
5783
|
+
const uint32_t acctype = f32acc ? 4 : 2;
|
|
5784
|
+
const uint32_t f16vec4 = 8;
|
|
5785
|
+
|
|
5786
|
+
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
5787
|
+
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
5788
|
+
|
|
5789
|
+
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
|
5790
|
+
|
|
5791
|
+
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
|
5792
|
+
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
5793
|
+
|
|
5794
|
+
const uint32_t kshstride = D / 4 + 2;
|
|
5795
|
+
const uint32_t ksh = Bc * kshstride * f16vec4;
|
|
5796
|
+
|
|
5797
|
+
const uint32_t slope = Br * sizeof(float);
|
|
5798
|
+
|
|
5799
|
+
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
|
5800
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
5801
|
+
|
|
5802
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
5803
|
+
|
|
5804
|
+
return supported;
|
|
5805
|
+
}
|
|
5806
|
+
|
|
5691
5807
|
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
|
5692
5808
|
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
|
5693
5809
|
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
|
@@ -5738,7 +5854,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5738
5854
|
assert(q->type == GGML_TYPE_F32);
|
|
5739
5855
|
assert(k->type == v->type);
|
|
5740
5856
|
|
|
5741
|
-
|
|
5857
|
+
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
|
5858
|
+
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
|
5859
|
+
|
|
5860
|
+
if (path == FA_COOPMAT1) {
|
|
5861
|
+
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
5862
|
+
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
5863
|
+
|
|
5864
|
+
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
|
5865
|
+
|
|
5866
|
+
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
5867
|
+
path = FA_SCALAR;
|
|
5868
|
+
}
|
|
5869
|
+
}
|
|
5742
5870
|
|
|
5743
5871
|
uint32_t gqa_ratio = 1;
|
|
5744
5872
|
uint32_t qk_ratio = neq2 / nek2;
|
|
@@ -5746,9 +5874,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5746
5874
|
uint32_t workgroups_y = (uint32_t)neq2;
|
|
5747
5875
|
uint32_t workgroups_z = (uint32_t)neq3;
|
|
5748
5876
|
|
|
5749
|
-
// For scalar FA, we can use the "large" size to accommodate qga.
|
|
5750
|
-
// For
|
|
5751
|
-
|
|
5877
|
+
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
|
|
5878
|
+
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
|
5879
|
+
uint32_t max_gqa;
|
|
5880
|
+
switch (path) {
|
|
5881
|
+
case FA_SCALAR:
|
|
5882
|
+
case FA_COOPMAT1:
|
|
5883
|
+
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
5884
|
+
max_gqa = scalar_flash_attention_num_large_rows;
|
|
5885
|
+
break;
|
|
5886
|
+
case FA_COOPMAT2:
|
|
5887
|
+
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
5888
|
+
break;
|
|
5889
|
+
default:
|
|
5890
|
+
GGML_ASSERT(0);
|
|
5891
|
+
}
|
|
5752
5892
|
|
|
5753
5893
|
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
5754
5894
|
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
|
@@ -5761,11 +5901,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5761
5901
|
}
|
|
5762
5902
|
|
|
5763
5903
|
vk_pipeline *pipelines;
|
|
5764
|
-
|
|
5765
|
-
|
|
5766
|
-
|
|
5904
|
+
bool small_rows = N <= get_fa_num_small_rows(path);
|
|
5905
|
+
|
|
5906
|
+
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
|
5907
|
+
// So use scalar instead.
|
|
5908
|
+
if (small_rows && path == FA_COOPMAT1) {
|
|
5909
|
+
path = FA_SCALAR;
|
|
5910
|
+
}
|
|
5767
5911
|
|
|
5768
|
-
|
|
5912
|
+
// scalar is faster than coopmat2 when N==1
|
|
5913
|
+
if (N == 1 && path == FA_COOPMAT2) {
|
|
5914
|
+
path = FA_SCALAR;
|
|
5915
|
+
}
|
|
5916
|
+
|
|
5917
|
+
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
|
5918
|
+
|
|
5919
|
+
switch (path) {
|
|
5920
|
+
case FA_SCALAR:
|
|
5769
5921
|
switch (D) {
|
|
5770
5922
|
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
5771
5923
|
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
@@ -5777,7 +5929,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5777
5929
|
GGML_ASSERT(!"unsupported D value");
|
|
5778
5930
|
return;
|
|
5779
5931
|
}
|
|
5780
|
-
|
|
5932
|
+
break;
|
|
5933
|
+
case FA_COOPMAT1:
|
|
5934
|
+
switch (D) {
|
|
5935
|
+
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5936
|
+
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5937
|
+
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5938
|
+
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5939
|
+
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5940
|
+
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5941
|
+
default:
|
|
5942
|
+
GGML_ASSERT(!"unsupported D value");
|
|
5943
|
+
return;
|
|
5944
|
+
}
|
|
5945
|
+
break;
|
|
5946
|
+
case FA_COOPMAT2:
|
|
5781
5947
|
switch (D) {
|
|
5782
5948
|
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5783
5949
|
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
@@ -5789,6 +5955,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5789
5955
|
GGML_ASSERT(!"unsupported D value");
|
|
5790
5956
|
return;
|
|
5791
5957
|
}
|
|
5958
|
+
break;
|
|
5959
|
+
default:
|
|
5960
|
+
GGML_ASSERT(0);
|
|
5792
5961
|
}
|
|
5793
5962
|
assert(pipelines);
|
|
5794
5963
|
|
|
@@ -6284,6 +6453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
6284
6453
|
case GGML_OP_ROPE:
|
|
6285
6454
|
case GGML_OP_RMS_NORM:
|
|
6286
6455
|
case GGML_OP_CONV_2D_DW:
|
|
6456
|
+
case GGML_OP_IM2COL:
|
|
6287
6457
|
return true;
|
|
6288
6458
|
default:
|
|
6289
6459
|
return false;
|
|
@@ -6582,7 +6752,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
6582
6752
|
case GGML_OP_UNARY:
|
|
6583
6753
|
case GGML_OP_CONV_2D_DW:
|
|
6584
6754
|
{
|
|
6585
|
-
|
|
6755
|
+
uint32_t ne = ggml_nelements(dst);
|
|
6756
|
+
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
6757
|
+
// Convert from number of logical elements to 2- or 4-byte units.
|
|
6758
|
+
ne /= ggml_blck_size(src0->type);
|
|
6759
|
+
if ((ggml_type_size(src0->type) % 4) == 0) {
|
|
6760
|
+
ne *= ggml_type_size(src0->type) / 4;
|
|
6761
|
+
} else {
|
|
6762
|
+
ne *= ggml_type_size(src0->type) / 2;
|
|
6763
|
+
}
|
|
6764
|
+
}
|
|
6586
6765
|
if (ne > 262144) {
|
|
6587
6766
|
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
|
6588
6767
|
} else if (ne > 512) {
|
|
@@ -7132,8 +7311,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
7132
7311
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7133
7312
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7134
7313
|
|
|
7314
|
+
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
|
7315
|
+
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
|
7316
|
+
// Convert from number of logical elements to 2- or 4-byte units.
|
|
7317
|
+
ne /= ggml_blck_size(src0->type);
|
|
7318
|
+
if ((ggml_type_size(src0->type) % 4) == 0) {
|
|
7319
|
+
ne *= ggml_type_size(src0->type) / 4;
|
|
7320
|
+
} else {
|
|
7321
|
+
ne *= ggml_type_size(src0->type) / 2;
|
|
7322
|
+
}
|
|
7323
|
+
}
|
|
7324
|
+
|
|
7135
7325
|
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
|
7136
|
-
|
|
7326
|
+
ne,
|
|
7137
7327
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7138
7328
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7139
7329
|
0,
|
|
@@ -8696,7 +8886,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
8696
8886
|
|
|
8697
8887
|
ctx->tensor_ctxs[node_idx] = compute_ctx;
|
|
8698
8888
|
|
|
8699
|
-
#if defined(GGML_VULKAN_CHECK_RESULTS)
|
|
8889
|
+
#if defined(GGML_VULKAN_CHECK_RESULTS)
|
|
8700
8890
|
// Force context reset on each node so that each tensor ends up in its own context
|
|
8701
8891
|
// and can be run and compared to its CPU equivalent separately
|
|
8702
8892
|
last_node = true;
|
|
@@ -9115,8 +9305,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
|
|
|
9115
9305
|
try {
|
|
9116
9306
|
ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
|
|
9117
9307
|
} catch (vk::SystemError& e) {
|
|
9118
|
-
|
|
9119
|
-
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
|
|
9308
|
+
GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
|
|
9120
9309
|
// fallback to cpu buffer
|
|
9121
9310
|
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
|
9122
9311
|
}
|
|
@@ -9317,6 +9506,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9317
9506
|
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
|
9318
9507
|
int submit_node_idx = 0; // index to first node in a batch
|
|
9319
9508
|
|
|
9509
|
+
vk_context compute_ctx;
|
|
9510
|
+
if (vk_perf_logger_enabled) {
|
|
9511
|
+
// allocate/resize the query pool
|
|
9512
|
+
if (ctx->device->num_queries < cgraph->n_nodes + 1) {
|
|
9513
|
+
if (ctx->device->query_pool) {
|
|
9514
|
+
ctx->device->device.destroyQueryPool(ctx->device->query_pool);
|
|
9515
|
+
}
|
|
9516
|
+
VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO };
|
|
9517
|
+
query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP;
|
|
9518
|
+
query_create_info.queryCount = cgraph->n_nodes + 100;
|
|
9519
|
+
ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
|
|
9520
|
+
ctx->device->num_queries = query_create_info.queryCount;
|
|
9521
|
+
}
|
|
9522
|
+
|
|
9523
|
+
ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
|
|
9524
|
+
|
|
9525
|
+
GGML_ASSERT(ctx->compute_ctx.expired());
|
|
9526
|
+
compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
9527
|
+
ctx->compute_ctx = compute_ctx;
|
|
9528
|
+
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
9529
|
+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
|
|
9530
|
+
}
|
|
9531
|
+
|
|
9320
9532
|
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
|
9321
9533
|
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
|
9322
9534
|
// (and scaled down based on model size, so smaller models submit earlier).
|
|
@@ -9344,6 +9556,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9344
9556
|
|
|
9345
9557
|
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
|
|
9346
9558
|
|
|
9559
|
+
if (vk_perf_logger_enabled) {
|
|
9560
|
+
if (ctx->compute_ctx.expired()) {
|
|
9561
|
+
compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
9562
|
+
ctx->compute_ctx = compute_ctx;
|
|
9563
|
+
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
|
9564
|
+
} else {
|
|
9565
|
+
compute_ctx = ctx->compute_ctx.lock();
|
|
9566
|
+
}
|
|
9567
|
+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
|
|
9568
|
+
}
|
|
9569
|
+
|
|
9347
9570
|
if (enqueued) {
|
|
9348
9571
|
++submitted_nodes;
|
|
9349
9572
|
|
|
@@ -9365,9 +9588,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
9365
9588
|
}
|
|
9366
9589
|
}
|
|
9367
9590
|
|
|
9368
|
-
|
|
9369
|
-
|
|
9370
|
-
|
|
9591
|
+
if (vk_perf_logger_enabled) {
|
|
9592
|
+
// End the command buffer and submit/wait
|
|
9593
|
+
GGML_ASSERT(!ctx->compute_ctx.expired());
|
|
9594
|
+
compute_ctx = ctx->compute_ctx.lock();
|
|
9595
|
+
ggml_vk_ctx_end(compute_ctx);
|
|
9596
|
+
|
|
9597
|
+
ggml_vk_submit(compute_ctx, ctx->device->fence);
|
|
9598
|
+
VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
|
|
9599
|
+
ctx->device->device.resetFences({ ctx->device->fence });
|
|
9600
|
+
|
|
9601
|
+
// Get the results and pass them to the logger
|
|
9602
|
+
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
|
|
9603
|
+
ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
|
|
9604
|
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
9605
|
+
if (!ggml_vk_is_empty(cgraph->nodes[i])) {
|
|
9606
|
+
ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
|
|
9607
|
+
}
|
|
9608
|
+
}
|
|
9609
|
+
|
|
9610
|
+
ctx->device->perf_logger->print_timings();
|
|
9611
|
+
}
|
|
9371
9612
|
|
|
9372
9613
|
ggml_vk_graph_cleanup(ctx);
|
|
9373
9614
|
|
|
@@ -9718,6 +9959,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
9718
9959
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
9719
9960
|
return true;
|
|
9720
9961
|
}
|
|
9962
|
+
|
|
9963
|
+
// We can handle copying from a type to the same type if it's
|
|
9964
|
+
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
|
|
9965
|
+
// so the type/block size must be a multiple of 2.
|
|
9966
|
+
if (src0_type == src1_type &&
|
|
9967
|
+
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
|
|
9968
|
+
(ggml_type_size(src0_type) % 2) == 0) {
|
|
9969
|
+
return true;
|
|
9970
|
+
}
|
|
9721
9971
|
return false;
|
|
9722
9972
|
} break;
|
|
9723
9973
|
case GGML_OP_REPEAT:
|
|
@@ -10123,7 +10373,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
10123
10373
|
} else if (tensor->op == GGML_OP_CONCAT) {
|
|
10124
10374
|
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
|
|
10125
10375
|
} else if (tensor->op == GGML_OP_UPSCALE) {
|
|
10126
|
-
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3],
|
|
10376
|
+
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
|
|
10127
10377
|
} else if (tensor->op == GGML_OP_SCALE) {
|
|
10128
10378
|
const float * params = (const float *)tensor->op_params;
|
|
10129
10379
|
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
|
@@ -10412,7 +10662,8 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
10412
10662
|
ggml_vk_print_graph_origin(tensor, done);
|
|
10413
10663
|
GGML_ABORT("fatal error");
|
|
10414
10664
|
}
|
|
10415
|
-
|
|
10665
|
+
const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
|
|
10666
|
+
if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
|
|
10416
10667
|
first_error[0] = i0;
|
|
10417
10668
|
first_error[1] = i1;
|
|
10418
10669
|
first_error[2] = i2;
|
|
@@ -10424,7 +10675,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
10424
10675
|
// Special case, value is infinite, avoid NaN result in avg_err
|
|
10425
10676
|
// NaN also appears in results, if both are nan error is 0
|
|
10426
10677
|
if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
|
|
10427
|
-
avg_err += std::fabs(correct - result);
|
|
10678
|
+
avg_err += std::fabs(correct - result) / denom;
|
|
10428
10679
|
}
|
|
10429
10680
|
counter++;
|
|
10430
10681
|
}
|
|
@@ -10459,7 +10710,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
10459
10710
|
ggml_vk_print_graph_origin(tensor, done);
|
|
10460
10711
|
}
|
|
10461
10712
|
|
|
10462
|
-
if (avg_err > 0.
|
|
10713
|
+
if (avg_err > 0.5 || std::isnan(avg_err)) {
|
|
10463
10714
|
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
|
|
10464
10715
|
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
|
|
10465
10716
|
if (src0 != nullptr) {
|