@fugood/llama.node 0.3.3 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +1 -1
- package/src/LlamaContext.cpp +81 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
#include "ggml-vulkan.h"
|
|
2
2
|
#include <vulkan/vulkan_core.h>
|
|
3
|
-
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
|
|
3
|
+
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
|
|
4
4
|
#include <chrono>
|
|
5
|
+
#include "ggml-cpu.h"
|
|
5
6
|
#endif
|
|
6
7
|
|
|
7
8
|
#include <vulkan/vulkan.hpp>
|
|
@@ -43,12 +44,6 @@
|
|
|
43
44
|
|
|
44
45
|
#define MAX_VK_BUFFERS 256
|
|
45
46
|
|
|
46
|
-
#ifndef K_QUANTS_PER_ITERATION
|
|
47
|
-
#define K_QUANTS_PER_ITERATION 1
|
|
48
|
-
#else
|
|
49
|
-
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
|
|
50
|
-
#endif
|
|
51
|
-
|
|
52
47
|
#define VK_CHECK(err, msg) \
|
|
53
48
|
do { \
|
|
54
49
|
vk::Result err_ = (err); \
|
|
@@ -158,29 +153,53 @@ struct vk_device_struct {
|
|
|
158
153
|
std::string name;
|
|
159
154
|
uint64_t max_memory_allocation_size;
|
|
160
155
|
bool fp16;
|
|
156
|
+
bool pipeline_robustness;
|
|
161
157
|
vk::Device device;
|
|
162
158
|
uint32_t vendor_id;
|
|
163
159
|
vk_queue compute_queue;
|
|
164
160
|
vk_queue transfer_queue;
|
|
165
161
|
bool single_queue;
|
|
166
162
|
uint32_t subgroup_size;
|
|
163
|
+
uint32_t shader_core_count;
|
|
167
164
|
bool uma;
|
|
165
|
+
bool float_controls_rte_fp16;
|
|
166
|
+
|
|
167
|
+
bool subgroup_size_control;
|
|
168
|
+
uint32_t subgroup_min_size;
|
|
169
|
+
uint32_t subgroup_max_size;
|
|
170
|
+
bool subgroup_require_full_support;
|
|
171
|
+
|
|
172
|
+
bool coopmat_support;
|
|
173
|
+
bool coopmat_acc_f32_support;
|
|
174
|
+
bool coopmat_acc_f16_support;
|
|
175
|
+
uint32_t coopmat_m;
|
|
176
|
+
uint32_t coopmat_n;
|
|
177
|
+
uint32_t coopmat_k;
|
|
178
|
+
bool coopmat2;
|
|
168
179
|
|
|
169
180
|
size_t idx;
|
|
170
181
|
|
|
182
|
+
bool mul_mat_l;
|
|
183
|
+
bool mul_mat_m;
|
|
184
|
+
bool mul_mat_s;
|
|
185
|
+
bool mul_mat_id_l;
|
|
186
|
+
bool mul_mat_id_m;
|
|
187
|
+
bool mul_mat_id_s;
|
|
188
|
+
|
|
171
189
|
vk_matmul_pipeline pipeline_matmul_f32;
|
|
172
190
|
vk_matmul_pipeline pipeline_matmul_f32_f16;
|
|
173
191
|
vk_matmul_pipeline2 pipeline_matmul_f16;
|
|
174
192
|
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
|
175
193
|
vk_pipeline pipeline_matmul_split_k_reduce;
|
|
176
194
|
|
|
195
|
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
|
177
196
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
|
178
197
|
|
|
179
198
|
vk_matmul_pipeline pipeline_matmul_id_f32;
|
|
180
|
-
|
|
181
|
-
|
|
199
|
+
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
|
200
|
+
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
|
182
201
|
|
|
183
|
-
|
|
202
|
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
|
184
203
|
|
|
185
204
|
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
|
186
205
|
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
|
|
@@ -218,6 +237,7 @@ struct vk_device_struct {
|
|
|
218
237
|
vk_pipeline pipeline_tanh_f32;
|
|
219
238
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
220
239
|
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
|
240
|
+
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
|
221
241
|
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
|
|
222
242
|
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
|
|
223
243
|
vk_pipeline pipeline_argsort_f32;
|
|
@@ -225,6 +245,15 @@ struct vk_device_struct {
|
|
|
225
245
|
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
|
|
226
246
|
vk_pipeline pipeline_timestep_embedding_f32;
|
|
227
247
|
vk_pipeline pipeline_pool2d_f32;
|
|
248
|
+
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
249
|
+
|
|
250
|
+
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
251
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
252
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
253
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
254
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
|
|
255
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
256
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
228
257
|
|
|
229
258
|
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
|
230
259
|
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
|
@@ -337,6 +366,40 @@ struct vk_mat_vec_id_push_constants {
|
|
|
337
366
|
uint32_t nei0; uint32_t ne11;
|
|
338
367
|
};
|
|
339
368
|
|
|
369
|
+
struct vk_flash_attn_push_constants {
|
|
370
|
+
uint32_t N;
|
|
371
|
+
uint32_t KV;
|
|
372
|
+
|
|
373
|
+
uint32_t ne1;
|
|
374
|
+
uint32_t ne2;
|
|
375
|
+
uint32_t ne3;
|
|
376
|
+
|
|
377
|
+
uint32_t neq2;
|
|
378
|
+
uint32_t neq3;
|
|
379
|
+
uint32_t nek2;
|
|
380
|
+
uint32_t nek3;
|
|
381
|
+
uint32_t nev2;
|
|
382
|
+
uint32_t nev3;
|
|
383
|
+
uint32_t nem1;
|
|
384
|
+
|
|
385
|
+
uint32_t nb02;
|
|
386
|
+
uint32_t nb03;
|
|
387
|
+
uint32_t nb12;
|
|
388
|
+
uint32_t nb13;
|
|
389
|
+
uint32_t nb22;
|
|
390
|
+
uint32_t nb23;
|
|
391
|
+
uint32_t nb31;
|
|
392
|
+
|
|
393
|
+
float scale;
|
|
394
|
+
float max_bias;
|
|
395
|
+
float logit_softcap;
|
|
396
|
+
|
|
397
|
+
uint32_t mask;
|
|
398
|
+
uint32_t n_head_log2;
|
|
399
|
+
float m0;
|
|
400
|
+
float m1;
|
|
401
|
+
};
|
|
402
|
+
|
|
340
403
|
struct vk_op_push_constants {
|
|
341
404
|
uint32_t KX;
|
|
342
405
|
uint32_t KY;
|
|
@@ -350,7 +413,46 @@ struct vk_op_unary_push_constants {
|
|
|
350
413
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
|
351
414
|
uint32_t d_offset;
|
|
352
415
|
float param1; float param2;
|
|
416
|
+
uint32_t ne0_012mp; uint32_t ne0_012L;
|
|
417
|
+
uint32_t ne0_01mp; uint32_t ne0_01L;
|
|
418
|
+
uint32_t ne0_0mp; uint32_t ne0_0L;
|
|
419
|
+
uint32_t ne1_012mp; uint32_t ne1_012L;
|
|
420
|
+
uint32_t ne1_01mp; uint32_t ne1_01L;
|
|
421
|
+
uint32_t ne1_0mp; uint32_t ne1_0L;
|
|
353
422
|
};
|
|
423
|
+
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
|
424
|
+
|
|
425
|
+
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
|
426
|
+
// Precompute mp (m' in the paper) and L such that division
|
|
427
|
+
// can be computed using a multiply (high 32b of 64b result)
|
|
428
|
+
// and a shift:
|
|
429
|
+
//
|
|
430
|
+
// n/d = (mulhi(n, mp) + n) >> L;
|
|
431
|
+
static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
|
|
432
|
+
{
|
|
433
|
+
// compute L = ceil(log2(d));
|
|
434
|
+
L = 0;
|
|
435
|
+
while (L < 32 && (uint32_t{1} << L) < d) {
|
|
436
|
+
L++;
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
template <typename T> void init_pushconst_fastdiv(T &p) {
|
|
443
|
+
GGML_UNUSED(p);
|
|
444
|
+
static_assert(!std::is_const<T>::value, "unexpected type");
|
|
445
|
+
}
|
|
446
|
+
|
|
447
|
+
template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
|
|
448
|
+
// Compute magic values to divide by these six numbers.
|
|
449
|
+
init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
|
|
450
|
+
init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
|
|
451
|
+
init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
|
|
452
|
+
init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
|
|
453
|
+
init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
|
|
454
|
+
init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
|
|
455
|
+
}
|
|
354
456
|
|
|
355
457
|
struct vk_op_binary_push_constants {
|
|
356
458
|
uint32_t ne;
|
|
@@ -388,6 +490,7 @@ struct vk_op_soft_max_push_constants {
|
|
|
388
490
|
float m0;
|
|
389
491
|
float m1;
|
|
390
492
|
uint32_t n_head_log2;
|
|
493
|
+
uint32_t nrows_x;
|
|
391
494
|
};
|
|
392
495
|
|
|
393
496
|
struct vk_op_argsort_push_constants {
|
|
@@ -426,6 +529,13 @@ struct vk_op_pool2d_push_constants {
|
|
|
426
529
|
int32_t p0; int32_t p1;
|
|
427
530
|
};
|
|
428
531
|
|
|
532
|
+
struct vk_op_rwkv_wkv6_push_constants {
|
|
533
|
+
uint32_t B;
|
|
534
|
+
uint32_t T;
|
|
535
|
+
uint32_t C;
|
|
536
|
+
uint32_t H;
|
|
537
|
+
};
|
|
538
|
+
|
|
429
539
|
// Allow pre-recording command buffers
|
|
430
540
|
struct vk_staging_memcpy {
|
|
431
541
|
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
@@ -652,8 +762,12 @@ static uint32_t compile_count = 0;
|
|
|
652
762
|
static std::mutex compile_count_mutex;
|
|
653
763
|
static std::condition_variable compile_count_cond;
|
|
654
764
|
|
|
655
|
-
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
|
656
|
-
|
|
765
|
+
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
|
766
|
+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
|
767
|
+
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
|
|
768
|
+
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
|
|
769
|
+
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
|
|
770
|
+
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
|
|
657
771
|
GGML_ASSERT(parameter_count > 0);
|
|
658
772
|
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
|
659
773
|
|
|
@@ -712,16 +826,39 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|
|
712
826
|
specialization_constants.data()
|
|
713
827
|
);
|
|
714
828
|
|
|
829
|
+
vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
|
|
830
|
+
|
|
831
|
+
if (device->subgroup_require_full_support && require_full_subgroups) {
|
|
832
|
+
pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
|
|
833
|
+
}
|
|
834
|
+
|
|
715
835
|
vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
|
|
716
|
-
|
|
836
|
+
pipeline_shader_stage_create_flags,
|
|
717
837
|
vk::ShaderStageFlagBits::eCompute,
|
|
718
838
|
pipeline->shader_module,
|
|
719
839
|
entrypoint.c_str(),
|
|
720
840
|
&specialization_info);
|
|
841
|
+
|
|
842
|
+
vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
|
|
843
|
+
pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
|
|
844
|
+
if (device->subgroup_size_control && required_subgroup_size > 0) {
|
|
845
|
+
GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
|
|
846
|
+
pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
|
|
847
|
+
}
|
|
848
|
+
|
|
721
849
|
vk::ComputePipelineCreateInfo compute_pipeline_create_info(
|
|
722
|
-
vk::PipelineCreateFlags
|
|
850
|
+
vk::PipelineCreateFlags{},
|
|
723
851
|
pipeline_shader_create_info,
|
|
724
852
|
pipeline->layout);
|
|
853
|
+
|
|
854
|
+
vk::PipelineRobustnessCreateInfoEXT rci;
|
|
855
|
+
|
|
856
|
+
if (device->pipeline_robustness && disable_robustness) {
|
|
857
|
+
rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
|
|
858
|
+
rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
|
|
859
|
+
compute_pipeline_create_info.setPNext(&rci);
|
|
860
|
+
}
|
|
861
|
+
|
|
725
862
|
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
|
|
726
863
|
|
|
727
864
|
{
|
|
@@ -1214,52 +1351,186 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
1214
1351
|
);
|
|
1215
1352
|
}
|
|
1216
1353
|
|
|
1354
|
+
// number of rows/cols for flash attention shader
|
|
1355
|
+
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1356
|
+
static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1357
|
+
GGML_UNUSED(clamp);
|
|
1358
|
+
|
|
1359
|
+
// small rows, large cols
|
|
1360
|
+
if (small_rows) {
|
|
1361
|
+
return {flash_attention_num_small_rows, 128};
|
|
1362
|
+
}
|
|
1363
|
+
// small cols to reduce register count
|
|
1364
|
+
if (ggml_is_quantized(type) || D == 256) {
|
|
1365
|
+
return {64, 32};
|
|
1366
|
+
}
|
|
1367
|
+
return {64, 64};
|
|
1368
|
+
};
|
|
1369
|
+
|
|
1370
|
+
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
|
|
1371
|
+
// Needs to be kept up to date on shader changes
|
|
1372
|
+
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
|
1373
|
+
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
|
1374
|
+
const uint32_t warps = warptile[0] / warptile[10];
|
|
1375
|
+
|
|
1376
|
+
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
1377
|
+
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
|
1378
|
+
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
1379
|
+
|
|
1380
|
+
return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
|
|
1381
|
+
}
|
|
1382
|
+
|
|
1217
1383
|
static void ggml_vk_load_shaders(vk_device& device) {
|
|
1218
1384
|
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
|
1219
1385
|
|
|
1220
1386
|
std::cerr << "ggml_vulkan: Compiling shaders";
|
|
1221
1387
|
|
|
1222
|
-
//
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
uint32_t l_align, m_align, s_align;
|
|
1226
|
-
|
|
1227
|
-
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
|
1228
|
-
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
|
1229
|
-
s_warptile = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
|
|
1230
|
-
|
|
1231
|
-
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
|
|
1232
|
-
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
|
|
1233
|
-
s_warptile_mmq = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
|
|
1388
|
+
// some shaders have a minimum subgroup size
|
|
1389
|
+
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
|
|
1390
|
+
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
|
|
1234
1391
|
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1392
|
+
// mulmat
|
|
1393
|
+
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
|
1394
|
+
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
|
1395
|
+
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
|
1396
|
+
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
|
1397
|
+
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
|
1398
|
+
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
|
|
1399
|
+
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
|
1400
|
+
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
|
|
1238
1401
|
|
|
1239
|
-
l_align
|
|
1240
|
-
|
|
1241
|
-
|
|
1402
|
+
uint32_t l_align, m_align, s_align;
|
|
1403
|
+
if (device->coopmat2) {
|
|
1404
|
+
// spec constants and tile sizes for non-quant matmul/matmul_id
|
|
1405
|
+
l_warptile = { 256, 128, 256, 64 };
|
|
1406
|
+
m_warptile = { 256, 128, 128, 64 };
|
|
1407
|
+
s_warptile = { 128, 32, 16, 64 };
|
|
1408
|
+
l_wg_denoms = {128, 256, 1 };
|
|
1409
|
+
m_wg_denoms = {128, 128, 1 };
|
|
1410
|
+
s_wg_denoms = { 32, 16, 1 };
|
|
1411
|
+
|
|
1412
|
+
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
|
1413
|
+
l_warptile_mmq = { 256, 128, 256, 64 };
|
|
1414
|
+
m_warptile_mmq = { 256, 128, 128, 64 };
|
|
1415
|
+
s_warptile_mmq = { 256, 128, 128, 64 };
|
|
1416
|
+
l_mmq_wg_denoms = { 128, 256, 1 };
|
|
1417
|
+
m_mmq_wg_denoms = { 128, 128, 1 };
|
|
1418
|
+
s_mmq_wg_denoms = { 128, 128, 1 };
|
|
1419
|
+
|
|
1420
|
+
// spec constants and tile sizes for quant matmul (Qi_K)
|
|
1421
|
+
l_warptile_mmq_k = { 256, 128, 512, 16 };
|
|
1422
|
+
m_warptile_mmq_k = { 256, 128, 256, 16 };
|
|
1423
|
+
s_warptile_mmq_k = { 256, 32, 128, 64 };
|
|
1424
|
+
l_mmq_wg_denoms_k = { 128, 512, 1 };
|
|
1425
|
+
m_mmq_wg_denoms_k = { 128, 256, 1 };
|
|
1426
|
+
s_mmq_wg_denoms_k = { 32, 128, 1 };
|
|
1427
|
+
|
|
1428
|
+
// spec constants and tile sizes for quant matmul_id
|
|
1429
|
+
l_warptile_mmqid = { 256, 128, 128, 16 };
|
|
1430
|
+
m_warptile_mmqid = { 256, 128, 64, 16 };
|
|
1431
|
+
s_warptile_mmqid = { 256, 64, 64, 16 };
|
|
1432
|
+
l_mmqid_wg_denoms = { 128, 128, 1 };
|
|
1433
|
+
m_mmqid_wg_denoms = { 128, 64, 1 };
|
|
1434
|
+
s_mmqid_wg_denoms = { 64, 64, 1 };
|
|
1435
|
+
|
|
1436
|
+
l_align = 128;
|
|
1437
|
+
m_align = 64;
|
|
1438
|
+
s_align = 32;
|
|
1439
|
+
} else {
|
|
1440
|
+
// Matrix cores require different warp group sizes
|
|
1441
|
+
const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
|
|
1442
|
+
const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
|
|
1443
|
+
const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
|
|
1444
|
+
const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
|
|
1445
|
+
const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
|
|
1446
|
+
const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
|
|
1447
|
+
const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
|
|
1448
|
+
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
|
|
1449
|
+
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
|
|
1450
|
+
|
|
1451
|
+
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
|
1452
|
+
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
|
1453
|
+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
|
1454
|
+
|
|
1455
|
+
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
|
|
1456
|
+
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
|
|
1457
|
+
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
|
|
1458
|
+
|
|
1459
|
+
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
|
1460
|
+
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
|
1461
|
+
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
|
1462
|
+
l_align = 128;
|
|
1463
|
+
m_align = 64;
|
|
1464
|
+
s_align = 32;
|
|
1465
|
+
|
|
1466
|
+
// Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
|
|
1467
|
+
// and tile sizes, this should handle 16KB, 32KB, and 48KB+.
|
|
1468
|
+
// This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
|
|
1469
|
+
// But the numbers happen to work out for 32KB shared memory size that when using the medium
|
|
1470
|
+
// size there's enough room for everything, and we assert for this.
|
|
1471
|
+
uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
|
1472
|
+
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
|
1473
|
+
l_warptile = m_warptile;
|
|
1474
|
+
l_wg_denoms = m_wg_denoms;
|
|
1475
|
+
shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
|
|
1476
|
+
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
|
1477
|
+
}
|
|
1478
|
+
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
|
1479
|
+
// assert mul_mat_mat_id shaders will fit.
|
|
1480
|
+
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
|
1481
|
+
}
|
|
1482
|
+
|
|
1483
|
+
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
|
1484
|
+
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
|
|
1485
|
+
if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
|
|
1486
|
+
l_warptile_mmq = m_warptile_mmq;
|
|
1487
|
+
l_mmq_wg_denoms = m_mmq_wg_denoms;
|
|
1488
|
+
} else {
|
|
1489
|
+
l_warptile_mmq = s_warptile_mmq;
|
|
1490
|
+
l_mmq_wg_denoms = s_mmq_wg_denoms;
|
|
1491
|
+
}
|
|
1492
|
+
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
|
|
1493
|
+
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
|
|
1494
|
+
}
|
|
1495
|
+
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
|
|
1496
|
+
// assert mul_mat_mat_id shaders will fit.
|
|
1497
|
+
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
|
|
1498
|
+
}
|
|
1499
|
+
// Disable medium and large matrix multiplication if not enough shared memory is available
|
|
1500
|
+
// Check mmq warptiles as the largest configuration
|
|
1501
|
+
// Throw an error if not enough for any matrix multiplication is available
|
|
1502
|
+
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
|
|
1503
|
+
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
|
|
1504
|
+
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
|
|
1505
|
+
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
|
|
1506
|
+
device->mul_mat_m = false;
|
|
1507
|
+
device->mul_mat_l = false;
|
|
1508
|
+
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
|
|
1509
|
+
device->mul_mat_l = false;
|
|
1510
|
+
}
|
|
1511
|
+
|
|
1512
|
+
// Disable mul_mat_id if not enough shared memory is available
|
|
1513
|
+
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
|
|
1514
|
+
device->mul_mat_id_s = false;
|
|
1515
|
+
device->mul_mat_id_m = false;
|
|
1516
|
+
device->mul_mat_id_l = false;
|
|
1517
|
+
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
|
|
1518
|
+
device->mul_mat_id_m = false;
|
|
1519
|
+
device->mul_mat_id_l = false;
|
|
1520
|
+
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
|
|
1521
|
+
device->mul_mat_id_l = false;
|
|
1522
|
+
}
|
|
1523
|
+
}
|
|
1242
1524
|
|
|
1243
1525
|
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1244
1526
|
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1245
1527
|
|
|
1246
1528
|
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1247
|
-
device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1248
|
-
device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1249
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1250
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1251
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1252
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1253
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1254
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1255
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1256
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1257
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1258
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1259
|
-
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1260
1529
|
|
|
1261
1530
|
std::vector<std::future<void>> compiles;
|
|
1262
|
-
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
|
1531
|
+
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
|
1532
|
+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
|
|
1533
|
+
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
|
|
1263
1534
|
{
|
|
1264
1535
|
// wait until fewer than N compiles are in progress
|
|
1265
1536
|
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
|
|
@@ -1269,144 +1540,368 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1269
1540
|
}
|
|
1270
1541
|
compile_count++;
|
|
1271
1542
|
}
|
|
1272
|
-
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
|
|
1543
|
+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
|
|
1544
|
+
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
1273
1545
|
};
|
|
1274
1546
|
|
|
1275
|
-
|
|
1547
|
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
1548
|
+
if (device->coopmat2) {
|
|
1549
|
+
|
|
1550
|
+
auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
1551
|
+
return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
|
|
1552
|
+
};
|
|
1553
|
+
|
|
1554
|
+
auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
1555
|
+
// For large number of rows, 128 invocations seems to work best.
|
|
1556
|
+
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
1557
|
+
// can't use 256 for D==80.
|
|
1558
|
+
uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
|
1559
|
+
auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
|
|
1560
|
+
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
|
|
1561
|
+
};
|
|
1562
|
+
|
|
1563
|
+
#define CREATE_FA2(TYPE, NAMELC, D) \
|
|
1564
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
|
1565
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
|
1566
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
|
|
1567
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
|
|
1568
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
|
1569
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
|
1570
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
|
|
1571
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
|
|
1572
|
+
|
|
1573
|
+
#define CREATE_FA(TYPE, NAMELC) \
|
|
1574
|
+
CREATE_FA2(TYPE, NAMELC, 64) \
|
|
1575
|
+
CREATE_FA2(TYPE, NAMELC, 80) \
|
|
1576
|
+
CREATE_FA2(TYPE, NAMELC, 96) \
|
|
1577
|
+
CREATE_FA2(TYPE, NAMELC, 112) \
|
|
1578
|
+
CREATE_FA2(TYPE, NAMELC, 128) \
|
|
1579
|
+
CREATE_FA2(TYPE, NAMELC, 256)
|
|
1580
|
+
|
|
1581
|
+
CREATE_FA(GGML_TYPE_F16, f16)
|
|
1582
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0)
|
|
1583
|
+
CREATE_FA(GGML_TYPE_Q4_1, q4_1)
|
|
1584
|
+
CREATE_FA(GGML_TYPE_Q5_0, q5_0)
|
|
1585
|
+
CREATE_FA(GGML_TYPE_Q5_1, q5_1)
|
|
1586
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0)
|
|
1587
|
+
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
1588
|
+
//CREATE_FA(GGML_TYPE_Q2_K, q2_k)
|
|
1589
|
+
//CREATE_FA(GGML_TYPE_Q3_K, q3_k)
|
|
1590
|
+
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
|
1591
|
+
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
|
1592
|
+
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
|
1593
|
+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
|
|
1594
|
+
#undef CREATE_FA
|
|
1595
|
+
|
|
1276
1596
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1277
1597
|
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1278
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ##
|
|
1279
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ##
|
|
1280
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ##
|
|
1281
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ##
|
|
1282
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ##
|
|
1283
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ##
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
CREATE_MM(
|
|
1288
|
-
CREATE_MM(
|
|
1289
|
-
|
|
1290
|
-
CREATE_MM(
|
|
1291
|
-
CREATE_MM(
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
CREATE_MM(
|
|
1297
|
-
CREATE_MM(
|
|
1298
|
-
CREATE_MM(
|
|
1299
|
-
CREATE_MM(
|
|
1300
|
-
CREATE_MM(
|
|
1301
|
-
CREATE_MM(
|
|
1302
|
-
|
|
1303
|
-
CREATE_MM(
|
|
1304
|
-
CREATE_MM(
|
|
1305
|
-
CREATE_MM(
|
|
1306
|
-
|
|
1307
|
-
CREATE_MM(
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1312
|
-
|
|
1313
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1314
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1315
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1316
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1317
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1318
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat_id[
|
|
1598
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
1599
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
1600
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
1601
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
|
1602
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
|
1603
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
1604
|
+
|
|
1605
|
+
// Create 2 variants, {f16,f32} accumulator
|
|
1606
|
+
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1607
|
+
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1608
|
+
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1609
|
+
|
|
1610
|
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
1611
|
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
1612
|
+
|
|
1613
|
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
1614
|
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
1615
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1616
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1617
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1618
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1619
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1620
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
1621
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
1622
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
1623
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
1624
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
|
1625
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1626
|
+
|
|
1627
|
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
1628
|
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
1629
|
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
1630
|
+
|
|
1631
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1632
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1633
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1634
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1635
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1636
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1637
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1638
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1639
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1640
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1641
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1642
|
+
#undef CREATE_MM
|
|
1643
|
+
#undef CREATE_MM2
|
|
1644
|
+
} else
|
|
1645
|
+
#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
1646
|
+
if (device->coopmat_support) {
|
|
1647
|
+
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1648
|
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1649
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1650
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
|
1651
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1652
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
|
1653
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1654
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
|
1655
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1656
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
|
1657
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1658
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
|
1659
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1660
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
|
1661
|
+
|
|
1662
|
+
// Create 2 variants, {f16,f32} accumulator
|
|
1663
|
+
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1664
|
+
if (device->coopmat_acc_f16_support) { \
|
|
1665
|
+
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1666
|
+
} \
|
|
1667
|
+
if (device->coopmat_acc_f32_support) { \
|
|
1668
|
+
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1669
|
+
} \
|
|
1670
|
+
|
|
1671
|
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1672
|
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1673
|
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1674
|
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1675
|
+
|
|
1676
|
+
if (device->coopmat_acc_f16_support) {
|
|
1677
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1678
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1679
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1680
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1681
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1682
|
+
|
|
1683
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1684
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1685
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1686
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1687
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1688
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1689
|
+
} else {
|
|
1690
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1691
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1692
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1693
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1694
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1695
|
+
|
|
1696
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1697
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1698
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1699
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1700
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1701
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1702
|
+
}
|
|
1703
|
+
|
|
1704
|
+
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
|
1705
|
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
|
1706
|
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1707
|
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1708
|
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1709
|
+
|
|
1710
|
+
if (device->coopmat_acc_f16_support) {
|
|
1711
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1712
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1713
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1714
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1715
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1716
|
+
|
|
1717
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1718
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1719
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1720
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1721
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1722
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1723
|
+
} else {
|
|
1724
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1725
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1726
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1727
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1728
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1729
|
+
|
|
1730
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1731
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1732
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1733
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1734
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1735
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1736
|
+
}
|
|
1737
|
+
}
|
|
1738
|
+
#undef CREATE_MM2
|
|
1739
|
+
#undef CREATE_MM
|
|
1740
|
+
} else if (device->fp16) {
|
|
1741
|
+
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1742
|
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1743
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1744
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
1745
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1746
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
1747
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1748
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
1749
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1750
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
|
1751
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1752
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
|
1753
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1754
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
1755
|
+
|
|
1756
|
+
// Create 2 variants, {f16,f32} accumulator
|
|
1757
|
+
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1758
|
+
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1759
|
+
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1760
|
+
|
|
1761
|
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1762
|
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1763
|
+
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1764
|
+
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1765
|
+
|
|
1766
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1767
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1768
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1769
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1770
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1771
|
+
|
|
1772
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1773
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1774
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1775
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1776
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1777
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1778
|
+
|
|
1779
|
+
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
|
1780
|
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
|
1781
|
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1782
|
+
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1783
|
+
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1784
|
+
|
|
1785
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1786
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1787
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1788
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1789
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1790
|
+
|
|
1791
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1792
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1793
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1794
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1795
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1796
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1797
|
+
}
|
|
1798
|
+
#undef CREATE_MM2
|
|
1319
1799
|
#undef CREATE_MM
|
|
1320
1800
|
} else {
|
|
1321
1801
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1322
|
-
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
CREATE_MM(
|
|
1337
|
-
CREATE_MM(
|
|
1338
|
-
CREATE_MM(
|
|
1339
|
-
CREATE_MM(
|
|
1340
|
-
|
|
1341
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1342
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1343
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1344
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1345
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
CREATE_MM(
|
|
1349
|
-
CREATE_MM(
|
|
1350
|
-
CREATE_MM(
|
|
1351
|
-
|
|
1352
|
-
CREATE_MM(
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1802
|
+
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1803
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1804
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
1805
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1806
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
1807
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1808
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
1809
|
+
if (device->mul_mat ## ID ## _l) \
|
|
1810
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
|
1811
|
+
if (device->mul_mat ## ID ## _m) \
|
|
1812
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
|
1813
|
+
if (device->mul_mat ## ID ## _s) \
|
|
1814
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
1815
|
+
|
|
1816
|
+
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1817
|
+
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1818
|
+
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1819
|
+
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1820
|
+
|
|
1821
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1822
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1823
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1824
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1825
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1826
|
+
|
|
1827
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1828
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1829
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1830
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1831
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1832
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1833
|
+
|
|
1834
|
+
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
|
1835
|
+
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
|
1836
|
+
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1837
|
+
CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1838
|
+
CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1839
|
+
|
|
1840
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1841
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1842
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1843
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1844
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1845
|
+
|
|
1846
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1847
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1848
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1849
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1850
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1851
|
+
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1852
|
+
}
|
|
1364
1853
|
#undef CREATE_MM
|
|
1365
1854
|
}
|
|
1366
1855
|
|
|
1367
1856
|
// mul mat vec
|
|
1368
|
-
|
|
1857
|
+
|
|
1858
|
+
// AMD GCN and Intel graphics cards perform best when the number of rows per shader is doubled
|
|
1859
|
+
uint32_t rm = 1;
|
|
1860
|
+
if ((device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size == 64 && device->subgroup_max_size == 64) || device->vendor_id == VK_VENDOR_ID_INTEL)
|
|
1861
|
+
rm = 2;
|
|
1862
|
+
|
|
1863
|
+
// computing additional rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
|
|
1369
1864
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1370
1865
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1371
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1372
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1373
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1374
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1375
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
|
|
1376
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1377
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1378
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1379
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1380
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1381
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {
|
|
1866
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1867
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1868
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1869
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1870
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
|
|
1871
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1872
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1873
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1874
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1875
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1876
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
|
|
1382
1877
|
|
|
1383
1878
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1384
1879
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1385
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1386
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1387
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1388
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1389
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
|
|
1390
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1391
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1392
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1393
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1394
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {
|
|
1395
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {
|
|
1880
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1881
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1882
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1883
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1884
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
|
|
1885
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1886
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1887
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1888
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1889
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1890
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
|
|
1396
1891
|
|
|
1397
1892
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1398
1893
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1399
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1400
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1401
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1402
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
1403
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
|
|
1404
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {
|
|
1405
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {
|
|
1406
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {
|
|
1407
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {
|
|
1408
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {
|
|
1409
|
-
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {
|
|
1894
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1895
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1896
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1897
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
|
|
1898
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
|
|
1899
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1900
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1901
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1902
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1903
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
|
|
1904
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
|
|
1410
1905
|
|
|
1411
1906
|
// dequant shaders
|
|
1412
1907
|
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
|
@@ -1441,7 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1441
1936
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
1442
1937
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
1443
1938
|
|
|
1444
|
-
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
|
|
1939
|
+
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
1445
1940
|
|
|
1446
1941
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
|
1447
1942
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
|
@@ -1497,26 +1992,39 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1497
1992
|
|
|
1498
1993
|
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
|
|
1499
1994
|
|
|
1500
|
-
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
|
|
1501
|
-
ggml_vk_create_pipeline(device, device->
|
|
1995
|
+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
1996
|
+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
|
1997
|
+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
1998
|
+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
|
1502
1999
|
|
|
1503
2000
|
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
1504
|
-
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
1505
|
-
|
|
1506
2001
|
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
1507
|
-
|
|
2002
|
+
|
|
2003
|
+
if (device->float_controls_rte_fp16) {
|
|
2004
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
2005
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
2006
|
+
} else {
|
|
2007
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
2008
|
+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
|
2009
|
+
}
|
|
1508
2010
|
|
|
1509
2011
|
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
|
|
1510
2012
|
|
|
1511
2013
|
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
|
1512
2014
|
|
|
1513
2015
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
|
1514
|
-
|
|
2016
|
+
if (device->float_controls_rte_fp16) {
|
|
2017
|
+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
|
2018
|
+
} else {
|
|
2019
|
+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
|
2020
|
+
}
|
|
1515
2021
|
|
|
1516
2022
|
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
|
1517
2023
|
|
|
1518
2024
|
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
|
|
1519
2025
|
|
|
2026
|
+
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
|
2027
|
+
|
|
1520
2028
|
for (auto &c : compiles) {
|
|
1521
2029
|
c.wait();
|
|
1522
2030
|
}
|
|
@@ -1550,12 +2058,40 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1550
2058
|
device->physical_device = physical_devices[dev_num];
|
|
1551
2059
|
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
|
1552
2060
|
|
|
2061
|
+
bool fp16_storage = false;
|
|
2062
|
+
bool fp16_compute = false;
|
|
1553
2063
|
bool maintenance4_support = false;
|
|
2064
|
+
bool sm_builtins = false;
|
|
2065
|
+
bool amd_shader_core_properties2 = false;
|
|
2066
|
+
bool pipeline_robustness = false;
|
|
2067
|
+
bool coopmat2_support = false;
|
|
2068
|
+
device->coopmat_support = false;
|
|
1554
2069
|
|
|
1555
2070
|
// Check if maintenance4 is supported
|
|
1556
2071
|
for (const auto& properties : ext_props) {
|
|
1557
2072
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
|
1558
2073
|
maintenance4_support = true;
|
|
2074
|
+
} else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
2075
|
+
fp16_storage = true;
|
|
2076
|
+
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
|
2077
|
+
fp16_compute = true;
|
|
2078
|
+
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
|
2079
|
+
sm_builtins = true;
|
|
2080
|
+
} else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
|
|
2081
|
+
amd_shader_core_properties2 = true;
|
|
2082
|
+
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
|
|
2083
|
+
pipeline_robustness = true;
|
|
2084
|
+
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
|
2085
|
+
device->subgroup_size_control = true;
|
|
2086
|
+
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
|
2087
|
+
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
|
2088
|
+
device->coopmat_support = true;
|
|
2089
|
+
device->coopmat_m = 0;
|
|
2090
|
+
device->coopmat_n = 0;
|
|
2091
|
+
device->coopmat_k = 0;
|
|
2092
|
+
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
|
2093
|
+
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
2094
|
+
coopmat2_support = true;
|
|
1559
2095
|
}
|
|
1560
2096
|
}
|
|
1561
2097
|
|
|
@@ -1563,18 +2099,51 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1563
2099
|
vk::PhysicalDeviceMaintenance3Properties props3;
|
|
1564
2100
|
vk::PhysicalDeviceMaintenance4Properties props4;
|
|
1565
2101
|
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
|
2102
|
+
vk::PhysicalDeviceDriverProperties driver_props;
|
|
2103
|
+
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
|
2104
|
+
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
|
|
2105
|
+
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
|
2106
|
+
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
|
2107
|
+
|
|
1566
2108
|
props2.pNext = &props3;
|
|
1567
2109
|
props3.pNext = &subgroup_props;
|
|
2110
|
+
subgroup_props.pNext = &driver_props;
|
|
2111
|
+
driver_props.pNext = &vk12_props;
|
|
2112
|
+
|
|
2113
|
+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
|
|
2114
|
+
|
|
1568
2115
|
if (maintenance4_support) {
|
|
1569
|
-
|
|
2116
|
+
last_struct->pNext = (VkBaseOutStructure *)&props4;
|
|
2117
|
+
last_struct = (VkBaseOutStructure *)&props4;
|
|
2118
|
+
}
|
|
2119
|
+
if (sm_builtins) {
|
|
2120
|
+
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
|
|
2121
|
+
last_struct = (VkBaseOutStructure *)&sm_props;
|
|
2122
|
+
}
|
|
2123
|
+
if (amd_shader_core_properties2) {
|
|
2124
|
+
last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
|
|
2125
|
+
last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
|
|
1570
2126
|
}
|
|
2127
|
+
if (device->subgroup_size_control) {
|
|
2128
|
+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
|
|
2129
|
+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
|
|
2130
|
+
}
|
|
2131
|
+
|
|
2132
|
+
#if defined(VK_NV_cooperative_matrix2)
|
|
2133
|
+
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
|
|
2134
|
+
if (coopmat2_support) {
|
|
2135
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
|
|
2136
|
+
last_struct = (VkBaseOutStructure *)&coopmat2_props;
|
|
2137
|
+
}
|
|
2138
|
+
#endif
|
|
2139
|
+
|
|
1571
2140
|
device->physical_device.getProperties2(&props2);
|
|
1572
2141
|
device->properties = props2.properties;
|
|
1573
2142
|
|
|
1574
2143
|
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
|
1575
2144
|
|
|
1576
2145
|
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
|
1577
|
-
device->max_memory_allocation_size = std::
|
|
2146
|
+
device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
|
|
1578
2147
|
} else if (maintenance4_support) {
|
|
1579
2148
|
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
|
|
1580
2149
|
} else {
|
|
@@ -1584,23 +2153,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1584
2153
|
device->vendor_id = device->properties.vendorID;
|
|
1585
2154
|
device->subgroup_size = subgroup_props.subgroupSize;
|
|
1586
2155
|
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
fp16_storage = true;
|
|
1594
|
-
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
|
1595
|
-
fp16_compute = true;
|
|
1596
|
-
}
|
|
2156
|
+
if (sm_builtins) {
|
|
2157
|
+
device->shader_core_count = sm_props.shaderSMCount;
|
|
2158
|
+
} else if (amd_shader_core_properties2) {
|
|
2159
|
+
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
|
|
2160
|
+
} else {
|
|
2161
|
+
device->shader_core_count = 0;
|
|
1597
2162
|
}
|
|
2163
|
+
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
|
|
1598
2164
|
|
|
1599
|
-
const
|
|
1600
|
-
const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
|
2165
|
+
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
|
1601
2166
|
|
|
1602
2167
|
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
|
1603
2168
|
|
|
2169
|
+
if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
|
|
2170
|
+
// Intel drivers don't support coopmat properly yet
|
|
2171
|
+
// Only RADV supports coopmat properly on AMD
|
|
2172
|
+
device->coopmat_support = false;
|
|
2173
|
+
}
|
|
2174
|
+
|
|
1604
2175
|
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
|
1605
2176
|
|
|
1606
2177
|
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
@@ -1638,10 +2209,149 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1638
2209
|
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
|
|
1639
2210
|
vk11_features.pNext = &vk12_features;
|
|
1640
2211
|
|
|
2212
|
+
last_struct = (VkBaseOutStructure *)&vk12_features;
|
|
2213
|
+
|
|
2214
|
+
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
|
|
2215
|
+
pl_robustness_features.pNext = nullptr;
|
|
2216
|
+
pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
|
|
2217
|
+
pl_robustness_features.pipelineRobustness = VK_FALSE;
|
|
2218
|
+
|
|
2219
|
+
if (pipeline_robustness) {
|
|
2220
|
+
last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
|
|
2221
|
+
last_struct = (VkBaseOutStructure *)&pl_robustness_features;
|
|
2222
|
+
device_extensions.push_back("VK_EXT_pipeline_robustness");
|
|
2223
|
+
}
|
|
2224
|
+
|
|
2225
|
+
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
|
|
2226
|
+
subgroup_size_control_features.pNext = nullptr;
|
|
2227
|
+
subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
|
|
2228
|
+
subgroup_size_control_features.computeFullSubgroups = false;
|
|
2229
|
+
subgroup_size_control_features.subgroupSizeControl = false;
|
|
2230
|
+
|
|
2231
|
+
if (device->subgroup_size_control) {
|
|
2232
|
+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
|
|
2233
|
+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
|
|
2234
|
+
}
|
|
2235
|
+
|
|
2236
|
+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
|
2237
|
+
coopmat_features.pNext = nullptr;
|
|
2238
|
+
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
|
|
2239
|
+
coopmat_features.cooperativeMatrix = VK_FALSE;
|
|
2240
|
+
|
|
2241
|
+
if (device->coopmat_support) {
|
|
2242
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
|
2243
|
+
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
|
2244
|
+
}
|
|
2245
|
+
|
|
2246
|
+
#if defined(VK_NV_cooperative_matrix2)
|
|
2247
|
+
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
|
|
2248
|
+
coopmat2_features.pNext = nullptr;
|
|
2249
|
+
coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
|
|
2250
|
+
if (coopmat2_support) {
|
|
2251
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
|
|
2252
|
+
last_struct = (VkBaseOutStructure *)&coopmat2_features;
|
|
2253
|
+
device_extensions.push_back("VK_NV_cooperative_matrix2");
|
|
2254
|
+
}
|
|
2255
|
+
#endif
|
|
2256
|
+
|
|
1641
2257
|
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
|
1642
2258
|
|
|
1643
2259
|
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
|
1644
2260
|
|
|
2261
|
+
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
|
|
2262
|
+
|
|
2263
|
+
if (device->subgroup_size_control) {
|
|
2264
|
+
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
|
2265
|
+
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
|
|
2266
|
+
}
|
|
2267
|
+
|
|
2268
|
+
device->subgroup_size_control = device->subgroup_size_control &&
|
|
2269
|
+
(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
2270
|
+
subgroup_size_control_features.subgroupSizeControl;
|
|
2271
|
+
|
|
2272
|
+
if (device->subgroup_size_control) {
|
|
2273
|
+
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
|
|
2274
|
+
device_extensions.push_back("VK_EXT_subgroup_size_control");
|
|
2275
|
+
}
|
|
2276
|
+
|
|
2277
|
+
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
|
2278
|
+
|
|
2279
|
+
if (coopmat2_support) {
|
|
2280
|
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
2281
|
+
if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
|
|
2282
|
+
coopmat2_features.cooperativeMatrixFlexibleDimensions &&
|
|
2283
|
+
coopmat2_features.cooperativeMatrixReductions &&
|
|
2284
|
+
coopmat2_features.cooperativeMatrixConversions &&
|
|
2285
|
+
coopmat2_features.cooperativeMatrixPerElementOperations &&
|
|
2286
|
+
coopmat2_features.cooperativeMatrixTensorAddressing &&
|
|
2287
|
+
coopmat2_features.cooperativeMatrixBlockLoads &&
|
|
2288
|
+
vk12_features.bufferDeviceAddress) {
|
|
2289
|
+
|
|
2290
|
+
std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
|
|
2291
|
+
uint32_t count = 0;
|
|
2292
|
+
|
|
2293
|
+
PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
|
|
2294
|
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
|
|
2295
|
+
(PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
|
|
2296
|
+
vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
|
|
2297
|
+
|
|
2298
|
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
|
|
2299
|
+
|
|
2300
|
+
VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
|
|
2301
|
+
empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
|
|
2302
|
+
flexible_dimensions.resize(count, empty_prop);
|
|
2303
|
+
|
|
2304
|
+
_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
|
|
2305
|
+
|
|
2306
|
+
bool found_fp16_128 = false,
|
|
2307
|
+
found_fp16_256 = false,
|
|
2308
|
+
found_fp32_128 = false,
|
|
2309
|
+
found_fp32_256 = false;
|
|
2310
|
+
// need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
|
|
2311
|
+
// with 32x16x16 and 256 with 32x32x16.
|
|
2312
|
+
for (auto &prop : flexible_dimensions) {
|
|
2313
|
+
if (prop.saturatingAccumulation == VK_FALSE &&
|
|
2314
|
+
prop.scope == VK_SCOPE_WORKGROUP_KHR &&
|
|
2315
|
+
prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
2316
|
+
prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
2317
|
+
|
|
2318
|
+
if (prop.workgroupInvocations == 128 &&
|
|
2319
|
+
prop.MGranularity <= 32 &&
|
|
2320
|
+
prop.NGranularity <= 16 &&
|
|
2321
|
+
prop.KGranularity <= 16) {
|
|
2322
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
2323
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
2324
|
+
found_fp16_128 = true;
|
|
2325
|
+
}
|
|
2326
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
2327
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
|
2328
|
+
found_fp32_128 = true;
|
|
2329
|
+
}
|
|
2330
|
+
}
|
|
2331
|
+
if (prop.workgroupInvocations == 256 &&
|
|
2332
|
+
prop.MGranularity <= 32 &&
|
|
2333
|
+
prop.NGranularity <= 32 &&
|
|
2334
|
+
prop.KGranularity <= 16) {
|
|
2335
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
|
|
2336
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
|
|
2337
|
+
found_fp16_256 = true;
|
|
2338
|
+
}
|
|
2339
|
+
if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
2340
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
|
|
2341
|
+
found_fp32_256 = true;
|
|
2342
|
+
}
|
|
2343
|
+
}
|
|
2344
|
+
}
|
|
2345
|
+
}
|
|
2346
|
+
if (found_fp16_128 && found_fp16_256 &&
|
|
2347
|
+
found_fp32_128 && found_fp32_256 &&
|
|
2348
|
+
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
|
2349
|
+
device->coopmat2 = true;
|
|
2350
|
+
}
|
|
2351
|
+
}
|
|
2352
|
+
#endif
|
|
2353
|
+
}
|
|
2354
|
+
|
|
1645
2355
|
if (!vk11_features.storageBuffer16BitAccess) {
|
|
1646
2356
|
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
|
|
1647
2357
|
throw std::runtime_error("Unsupported device");
|
|
@@ -1656,6 +2366,74 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1656
2366
|
if (device->fp16) {
|
|
1657
2367
|
device_extensions.push_back("VK_KHR_shader_float16_int8");
|
|
1658
2368
|
}
|
|
2369
|
+
|
|
2370
|
+
if (device->coopmat_support) {
|
|
2371
|
+
// Query supported shapes
|
|
2372
|
+
std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
|
|
2373
|
+
|
|
2374
|
+
PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
|
|
2375
|
+
(PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
|
|
2376
|
+
|
|
2377
|
+
uint32_t cm_props_num;
|
|
2378
|
+
|
|
2379
|
+
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
|
|
2380
|
+
|
|
2381
|
+
cm_props.resize(cm_props_num);
|
|
2382
|
+
|
|
2383
|
+
for (auto& prop : cm_props) {
|
|
2384
|
+
prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
|
|
2385
|
+
}
|
|
2386
|
+
|
|
2387
|
+
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
|
|
2388
|
+
|
|
2389
|
+
VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
|
|
2390
|
+
|
|
2391
|
+
for (auto& prop : cm_props) {
|
|
2392
|
+
VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
|
|
2393
|
+
|
|
2394
|
+
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
|
|
2395
|
+
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
|
|
2396
|
+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
|
2397
|
+
) {
|
|
2398
|
+
if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
|
|
2399
|
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
|
|
2400
|
+
// coopmat sizes not set yet
|
|
2401
|
+
if (device->coopmat_m == 0) {
|
|
2402
|
+
device->coopmat_acc_f32_support = true;
|
|
2403
|
+
device->coopmat_m = prop.MSize;
|
|
2404
|
+
device->coopmat_n = prop.NSize;
|
|
2405
|
+
device->coopmat_k = prop.KSize;
|
|
2406
|
+
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
|
2407
|
+
// Only enable if shape is identical
|
|
2408
|
+
device->coopmat_acc_f32_support = true;
|
|
2409
|
+
}
|
|
2410
|
+
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
|
2411
|
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
|
2412
|
+
// coopmat sizes not set yet
|
|
2413
|
+
if (device->coopmat_m == 0) {
|
|
2414
|
+
device->coopmat_acc_f16_support = true;
|
|
2415
|
+
device->coopmat_m = prop.MSize;
|
|
2416
|
+
device->coopmat_n = prop.NSize;
|
|
2417
|
+
device->coopmat_k = prop.KSize;
|
|
2418
|
+
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
|
2419
|
+
// Only enable if shape is identical
|
|
2420
|
+
device->coopmat_acc_f16_support = true;
|
|
2421
|
+
}
|
|
2422
|
+
}
|
|
2423
|
+
}
|
|
2424
|
+
}
|
|
2425
|
+
|
|
2426
|
+
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
|
2427
|
+
// No suitable matmul mode found
|
|
2428
|
+
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
|
2429
|
+
device->coopmat_support = false;
|
|
2430
|
+
}
|
|
2431
|
+
}
|
|
2432
|
+
|
|
2433
|
+
if (device->coopmat_support) {
|
|
2434
|
+
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
|
2435
|
+
}
|
|
2436
|
+
|
|
1659
2437
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
|
1660
2438
|
|
|
1661
2439
|
device_create_info = {
|
|
@@ -1671,6 +2449,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
1671
2449
|
ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
|
|
1672
2450
|
|
|
1673
2451
|
// Shaders
|
|
2452
|
+
// Disable matmul tile sizes early if performance low or not supported
|
|
2453
|
+
switch (device->vendor_id) {
|
|
2454
|
+
#ifndef GGML_VULKAN_RUN_TESTS
|
|
2455
|
+
case VK_VENDOR_ID_AMD:
|
|
2456
|
+
case VK_VENDOR_ID_INTEL:
|
|
2457
|
+
device->mul_mat_l = false;
|
|
2458
|
+
device->mul_mat_m = true;
|
|
2459
|
+
device->mul_mat_s = true;
|
|
2460
|
+
device->mul_mat_id_l = false;
|
|
2461
|
+
device->mul_mat_id_m = true;
|
|
2462
|
+
device->mul_mat_id_s = true;
|
|
2463
|
+
break;
|
|
2464
|
+
case VK_VENDOR_ID_APPLE:
|
|
2465
|
+
device->mul_mat_l = false;
|
|
2466
|
+
device->mul_mat_m = true;
|
|
2467
|
+
device->mul_mat_s = false;
|
|
2468
|
+
device->mul_mat_id_l = false;
|
|
2469
|
+
device->mul_mat_id_m = true;
|
|
2470
|
+
device->mul_mat_id_s = false;
|
|
2471
|
+
break;
|
|
2472
|
+
#endif
|
|
2473
|
+
default:
|
|
2474
|
+
device->mul_mat_l = true;
|
|
2475
|
+
device->mul_mat_m = true;
|
|
2476
|
+
device->mul_mat_s = true;
|
|
2477
|
+
device->mul_mat_id_l = true;
|
|
2478
|
+
device->mul_mat_id_m = true;
|
|
2479
|
+
device->mul_mat_id_s = true;
|
|
2480
|
+
break;
|
|
2481
|
+
}
|
|
2482
|
+
|
|
1674
2483
|
ggml_vk_load_shaders(device);
|
|
1675
2484
|
|
|
1676
2485
|
if (!device->single_queue) {
|
|
@@ -1728,15 +2537,31 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
1728
2537
|
|
|
1729
2538
|
bool fp16_storage = false;
|
|
1730
2539
|
bool fp16_compute = false;
|
|
2540
|
+
bool coopmat_support = false;
|
|
2541
|
+
bool coopmat2_support = false;
|
|
1731
2542
|
|
|
1732
2543
|
for (auto properties : ext_props) {
|
|
1733
2544
|
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
1734
2545
|
fp16_storage = true;
|
|
1735
2546
|
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
|
|
1736
2547
|
fp16_compute = true;
|
|
2548
|
+
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
|
2549
|
+
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
|
2550
|
+
coopmat_support = true;
|
|
2551
|
+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
2552
|
+
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
|
2553
|
+
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
2554
|
+
coopmat2_support = true;
|
|
2555
|
+
#endif
|
|
1737
2556
|
}
|
|
1738
2557
|
}
|
|
1739
2558
|
|
|
2559
|
+
if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
|
|
2560
|
+
// Intel drivers don't support coopmat properly yet
|
|
2561
|
+
// Only RADV supports coopmat properly on AMD
|
|
2562
|
+
coopmat_support = false;
|
|
2563
|
+
}
|
|
2564
|
+
|
|
1740
2565
|
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
|
1741
2566
|
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
|
1742
2567
|
|
|
@@ -1759,16 +2584,33 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
1759
2584
|
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
|
|
1760
2585
|
vk11_features.pNext = &vk12_features;
|
|
1761
2586
|
|
|
2587
|
+
// Pointer to the last chain element
|
|
2588
|
+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
|
|
2589
|
+
|
|
2590
|
+
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
|
2591
|
+
coopmat_features.pNext = nullptr;
|
|
2592
|
+
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
|
|
2593
|
+
coopmat_features.cooperativeMatrix = VK_FALSE;
|
|
2594
|
+
|
|
2595
|
+
if (coopmat_support) {
|
|
2596
|
+
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
|
2597
|
+
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
|
2598
|
+
}
|
|
2599
|
+
|
|
1762
2600
|
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
|
1763
2601
|
|
|
1764
2602
|
fp16 = fp16 && vk12_features.shaderFloat16;
|
|
1765
2603
|
|
|
2604
|
+
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
|
|
2605
|
+
|
|
2606
|
+
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
|
2607
|
+
|
|
1766
2608
|
std::string device_name = props2.properties.deviceName.data();
|
|
1767
|
-
GGML_LOG_DEBUG("ggml_vulkan: %
|
|
1768
|
-
idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
|
|
2609
|
+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
|
|
2610
|
+
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
|
|
1769
2611
|
|
|
1770
2612
|
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
|
1771
|
-
|
|
2613
|
+
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
|
|
1772
2614
|
}
|
|
1773
2615
|
}
|
|
1774
2616
|
|
|
@@ -1937,8 +2779,7 @@ void ggml_vk_instance_init() {
|
|
|
1937
2779
|
vk_instance.device_indices.push_back(0);
|
|
1938
2780
|
}
|
|
1939
2781
|
}
|
|
1940
|
-
GGML_LOG_DEBUG("ggml_vulkan: Found %
|
|
1941
|
-
|
|
2782
|
+
GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
|
|
1942
2783
|
|
|
1943
2784
|
for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
|
|
1944
2785
|
ggml_vk_print_gpu_info(i);
|
|
@@ -1994,7 +2835,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
|
|
1994
2835
|
return ctx->device->pipeline_dequant[type];
|
|
1995
2836
|
}
|
|
1996
2837
|
|
|
1997
|
-
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
|
|
2838
|
+
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
|
1998
2839
|
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
1999
2840
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
|
2000
2841
|
return ctx->device->pipeline_matmul_f32;
|
|
@@ -2002,14 +2843,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
2002
2843
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
|
2003
2844
|
return ctx->device->pipeline_matmul_f32_f16;
|
|
2004
2845
|
}
|
|
2005
|
-
if (
|
|
2006
|
-
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2846
|
+
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
|
2847
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
2848
|
+
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
|
2849
|
+
}
|
|
2850
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
2851
|
+
return ctx->device->pipeline_matmul_f16.f16acc;
|
|
2852
|
+
}
|
|
2853
|
+
} else {
|
|
2854
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
2855
|
+
return ctx->device->pipeline_matmul_f16_f32.f32acc;
|
|
2856
|
+
}
|
|
2857
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
2858
|
+
return ctx->device->pipeline_matmul_f16.f32acc;
|
|
2859
|
+
}
|
|
2010
2860
|
}
|
|
2011
2861
|
|
|
2012
|
-
if (src1_type != GGML_TYPE_F32) {
|
|
2862
|
+
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
|
|
2013
2863
|
return nullptr;
|
|
2014
2864
|
}
|
|
2015
2865
|
|
|
@@ -2030,7 +2880,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
2030
2880
|
return nullptr;
|
|
2031
2881
|
}
|
|
2032
2882
|
|
|
2033
|
-
|
|
2883
|
+
if (ctx->device->coopmat2) {
|
|
2884
|
+
assert(src1_type == GGML_TYPE_F16);
|
|
2885
|
+
return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
|
|
2886
|
+
}
|
|
2887
|
+
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
|
2034
2888
|
}
|
|
2035
2889
|
|
|
2036
2890
|
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
@@ -2059,16 +2913,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|
|
2059
2913
|
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
|
|
2060
2914
|
}
|
|
2061
2915
|
|
|
2062
|
-
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
|
|
2916
|
+
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
|
2063
2917
|
VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
|
|
2064
2918
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
|
2065
2919
|
return ctx->device->pipeline_matmul_id_f32;
|
|
2066
2920
|
}
|
|
2067
|
-
if (
|
|
2068
|
-
|
|
2069
|
-
|
|
2070
|
-
|
|
2071
|
-
|
|
2921
|
+
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
|
2922
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
2923
|
+
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
|
2924
|
+
}
|
|
2925
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
2926
|
+
return ctx->device->pipeline_matmul_id_f16.f16acc;
|
|
2927
|
+
}
|
|
2928
|
+
} else {
|
|
2929
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
2930
|
+
return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
|
|
2931
|
+
}
|
|
2932
|
+
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
|
2933
|
+
return ctx->device->pipeline_matmul_id_f16.f32acc;
|
|
2934
|
+
}
|
|
2072
2935
|
}
|
|
2073
2936
|
|
|
2074
2937
|
GGML_ASSERT(src1_type == GGML_TYPE_F32);
|
|
@@ -2090,7 +2953,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
|
2090
2953
|
return nullptr;
|
|
2091
2954
|
}
|
|
2092
2955
|
|
|
2093
|
-
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
|
|
2956
|
+
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
|
|
2094
2957
|
}
|
|
2095
2958
|
|
|
2096
2959
|
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
|
@@ -2659,55 +3522,44 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
|
|
|
2659
3522
|
dst->device->device.resetFences({ dst->device->fence });
|
|
2660
3523
|
}
|
|
2661
3524
|
|
|
2662
|
-
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
|
|
3525
|
+
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
|
|
2663
3526
|
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
|
2664
|
-
// if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
|
|
2665
|
-
// return 4;
|
|
2666
|
-
// }
|
|
2667
3527
|
|
|
2668
|
-
|
|
2669
|
-
|
|
2670
|
-
|
|
2671
|
-
|
|
2672
|
-
|
|
2673
|
-
|
|
2674
|
-
|
|
2675
|
-
|
|
3528
|
+
uint32_t split_k = 1;
|
|
3529
|
+
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
|
|
3530
|
+
// If k is 'large' and the SMs will fill less than halfway, use split_k.
|
|
3531
|
+
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
|
|
3532
|
+
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
|
|
3533
|
+
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
|
|
3534
|
+
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
|
|
3535
|
+
// Clamp to 2 or 4
|
|
3536
|
+
split_k = std::min(split_k, 4u);
|
|
3537
|
+
if (split_k == 3) {
|
|
3538
|
+
split_k = 2;
|
|
3539
|
+
}
|
|
3540
|
+
}
|
|
2676
3541
|
}
|
|
2677
|
-
return aligned ? mmp->a_m : mmp->m;
|
|
2678
|
-
|
|
2679
|
-
GGML_UNUSED(ctx);
|
|
2680
|
-
}
|
|
2681
3542
|
|
|
2682
|
-
|
|
2683
|
-
return aligned ? mmp->a_m : mmp->m;
|
|
2684
|
-
|
|
2685
|
-
GGML_UNUSED(ctx);
|
|
2686
|
-
}
|
|
2687
|
-
|
|
2688
|
-
static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
|
|
2689
|
-
return aligned ? mmp->a_s : mmp->s;
|
|
2690
|
-
|
|
2691
|
-
GGML_UNUSED(ctx);
|
|
3543
|
+
return split_k;
|
|
2692
3544
|
}
|
|
2693
3545
|
|
|
2694
3546
|
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
|
2695
3547
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
|
|
2696
|
-
|
|
2697
|
-
|
|
2698
|
-
|
|
2699
|
-
|
|
2700
|
-
|
|
2701
|
-
|
|
2702
|
-
|
|
2703
|
-
|
|
2704
|
-
|
|
3548
|
+
|
|
3549
|
+
if (ctx->device->coopmat2) {
|
|
3550
|
+
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
|
|
3551
|
+
return aligned ? mmp->a_l : mmp->l;
|
|
3552
|
+
}
|
|
3553
|
+
if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
|
|
3554
|
+
return aligned ? mmp->a_m : mmp->m;
|
|
3555
|
+
}
|
|
3556
|
+
return aligned ? mmp->a_s : mmp->s;
|
|
2705
3557
|
}
|
|
2706
3558
|
|
|
2707
|
-
if (m <= 32 || n <= 32) {
|
|
3559
|
+
if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
|
|
2708
3560
|
return aligned ? mmp->a_s : mmp->s;
|
|
2709
3561
|
}
|
|
2710
|
-
if (m <= 64 || n <= 64) {
|
|
3562
|
+
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
|
|
2711
3563
|
return aligned ? mmp->a_m : mmp->m;
|
|
2712
3564
|
}
|
|
2713
3565
|
return aligned ? mmp->a_l : mmp->l;
|
|
@@ -2742,6 +3594,33 @@ static void ggml_vk_matmul(
|
|
|
2742
3594
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
|
2743
3595
|
}
|
|
2744
3596
|
|
|
3597
|
+
// Choose the tile-size variant (l/m/s) of a mul_mat_id pipeline for an m x n
// result. When `aligned` is set, return the aligned variant (a_l/a_m/a_s).
// The device flags mul_mat_id_l/m/s restrict which variants get selected.
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ")");

    if (ctx->device->coopmat2) {
        // coopmat2: prefer the largest enabled tile whose workgroup dimensions
        // divide m and n exactly; otherwise fall back to smaller tiles.
        if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
            return aligned ? mmp->a_l : mmp->l;
        }
        if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;
    }

    // Otherwise choose by output size among the enabled variants:
    // small when either dimension is <= 32, medium when <= 64, else large.
    if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
        return aligned ? mmp->a_s : mmp->s;
    }
    if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
        return aligned ? mmp->a_m : mmp->m;
    }
    return aligned ? mmp->a_l : mmp->l;
}
|
|
3618
|
+
|
|
3619
|
+
// Return the K alignment required by the aligned mul_mat_id pipeline variant
// that ggml_vk_guess_matmul_id_pipeline would select for an m x n result.
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ")");
    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
}
|
|
3623
|
+
|
|
2745
3624
|
static void ggml_vk_matmul_id(
|
|
2746
3625
|
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
|
|
2747
3626
|
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
|
|
@@ -2812,13 +3691,15 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
|
|
2812
3691
|
elements = { ne, 1, 1 };
|
|
2813
3692
|
}
|
|
2814
3693
|
|
|
2815
|
-
|
|
3694
|
+
vk_op_unary_push_constants pc = {
|
|
2816
3695
|
(uint32_t)ne,
|
|
2817
3696
|
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
|
|
2818
3697
|
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
|
|
2819
3698
|
0,
|
|
2820
3699
|
0.0f, 0.0f,
|
|
3700
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
2821
3701
|
};
|
|
3702
|
+
init_pushconst_fastdiv(pc);
|
|
2822
3703
|
ggml_vk_sync_buffers(subctx);
|
|
2823
3704
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
|
|
2824
3705
|
}
|
|
@@ -2867,18 +3748,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
2867
3748
|
}
|
|
2868
3749
|
|
|
2869
3750
|
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
|
|
2870
|
-
|
|
3751
|
+
// Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
|
|
3752
|
+
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
|
3753
|
+
!ggml_vk_dim01_contiguous(src1);
|
|
2871
3754
|
|
|
2872
3755
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
2873
3756
|
|
|
2874
|
-
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
|
|
3757
|
+
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
|
2875
3758
|
|
|
2876
3759
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
2877
3760
|
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
|
2878
3761
|
|
|
2879
3762
|
if (qx_needs_dequant) {
|
|
2880
3763
|
// Fall back to dequant + f16 mulmat
|
|
2881
|
-
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
|
|
3764
|
+
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
|
2882
3765
|
}
|
|
2883
3766
|
|
|
2884
3767
|
// Not implemented
|
|
@@ -2891,10 +3774,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
2891
3774
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
|
2892
3775
|
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
|
2893
3776
|
|
|
2894
|
-
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
|
2895
|
-
|
|
2896
3777
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
|
2897
3778
|
|
|
3779
|
+
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
|
3780
|
+
|
|
2898
3781
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
2899
3782
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
2900
3783
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
|
@@ -2920,7 +3803,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
2920
3803
|
if (dryrun) {
|
|
2921
3804
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
|
2922
3805
|
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
|
2923
|
-
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 *
|
|
3806
|
+
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
|
|
2924
3807
|
if (
|
|
2925
3808
|
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
|
2926
3809
|
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
|
@@ -3187,7 +4070,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
3187
4070
|
|
|
3188
4071
|
if (ne01 > max_groups_x) {
|
|
3189
4072
|
groups_z = 64;
|
|
3190
|
-
groups_x
|
|
4073
|
+
groups_x = CEIL_DIV(groups_x, groups_z);
|
|
3191
4074
|
}
|
|
3192
4075
|
|
|
3193
4076
|
// compute
|
|
@@ -3442,7 +4325,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
3442
4325
|
|
|
3443
4326
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
3444
4327
|
|
|
3445
|
-
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
|
|
4328
|
+
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
|
3446
4329
|
|
|
3447
4330
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
3448
4331
|
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
|
@@ -3458,10 +4341,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
3458
4341
|
const uint64_t y_ne = ne11 * ne10;
|
|
3459
4342
|
const uint64_t d_ne = ne21 * ne20;
|
|
3460
4343
|
|
|
3461
|
-
const uint32_t kpad = ggml_vk_align_size(ne10,
|
|
4344
|
+
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
|
|
3462
4345
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
|
3463
4346
|
|
|
3464
|
-
vk_pipeline pipeline =
|
|
4347
|
+
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
|
|
3465
4348
|
|
|
3466
4349
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
3467
4350
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
@@ -3764,7 +4647,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|
|
3764
4647
|
|
|
3765
4648
|
if (ne01 > max_groups_x) {
|
|
3766
4649
|
groups_z = 64;
|
|
3767
|
-
groups_x
|
|
4650
|
+
groups_x = CEIL_DIV(groups_x, groups_z);
|
|
3768
4651
|
}
|
|
3769
4652
|
|
|
3770
4653
|
// compute
|
|
@@ -3789,6 +4672,167 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
3789
4672
|
}
|
|
3790
4673
|
}
|
|
3791
4674
|
|
|
4675
|
+
// Record a flash-attention dispatch that computes dst from q, k, v and an
// optional mask. With `dryrun`, it only requests a descriptor set for the
// chosen pipeline and returns.
// scale, max_bias and logit_softcap are read from dst->op_params[0..2];
// dst->op_params[3] == GGML_PREC_F32 selects the f32-accumulator pipelines.
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
    VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
    std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
    std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");

    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

    const uint32_t nem1 = mask ? mask->ne[1] : 0;
    const uint32_t nbm1 = mask ? mask->nb[1] : 0;

    const uint32_t D = neq0;  // row length shared by q, k, v and dst
    const uint32_t N = neq1;  // number of q rows
    const uint32_t KV = nek1; // number of k/v rows

    GGML_ASSERT(ne0 == D);
    GGML_ASSERT(ne2 == N);

    // input tensor rows must be contiguous
    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
    GGML_ASSERT(nbv0 == ggml_type_size(v->type));

    GGML_ASSERT(neq0 == D);
    GGML_ASSERT(nek0 == D);
    GGML_ASSERT(nev0 == D);

    GGML_ASSERT(neq1 == N);

    GGML_ASSERT(nev1 == nek1);

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    assert(dst->type == GGML_TYPE_F32);
    assert(q->type == GGML_TYPE_F32);
    assert(k->type == v->type);

    vk_pipeline *pipelines;
    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
    bool small_rows = N <= flash_attention_num_small_rows;
    switch (D) {
    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
    default:
        assert(!"unsupported D value");
        return;
    }
    assert(pipelines);

    // pipelines[1] is the aligned variant; use it when KV is a multiple of its alignment.
    bool aligned = (KV % pipelines[1]->align) == 0;
    vk_pipeline pipeline = pipelines[aligned];
    assert(pipeline);

    if (dryrun) {
        // Request descriptor sets
        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
        return;
    }

    float scale = 1.0f;
    float max_bias = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    if (logit_softcap != 0) {
        scale /= logit_softcap;
    }

    // Slope bases for max_bias are derived from q's dim 2 (its head count).
    const uint32_t n_head = neq2;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    ggml_vk_sync_buffers(subctx);

    vk_buffer d_Q, d_K, d_V, d_D, d_M;
    uint64_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;

    bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;

    if (ctx->device->uma) {
        // Each tensor writes its own offset variable. Previously every call
        // wrote into q_buf_offset, clobbering Q's offset and leaving the
        // K/V/D/M offsets unset whenever those buffers were host-visible.
        ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
        Q_uma = d_Q != nullptr;
        K_uma = d_K != nullptr;
        V_uma = d_V != nullptr;
        D_uma = d_D != nullptr;
        if (mask) {
            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
            M_uma = d_M != nullptr;
        }
    }

    ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;

    if (!Q_uma) {
        d_Q = q_buf_ctx->dev_buffer;
        q_buf_offset = vk_tensor_offset(q) + q->view_offs;
    }
    if (!K_uma) {
        d_K = k_buf_ctx->dev_buffer;
        k_buf_offset = vk_tensor_offset(k) + k->view_offs;
    }
    if (!V_uma) {
        d_V = v_buf_ctx->dev_buffer;
        v_buf_offset = vk_tensor_offset(v) + v->view_offs;
    }
    if (!D_uma) {
        d_D = d_buf_ctx->dev_buffer;
        d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
    }

    if (!M_uma) {
        // Without a mask, bind Q's buffer in the mask slot as a placeholder;
        // the push constants tell the shader whether a mask is present.
        d_M = d_Q;
        m_buf_offset = q_buf_offset;
        if (mask) {
            ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
            d_M = m_buf_ctx->dev_buffer;
            m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
        }
    }

    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                {
                                    vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
                                    vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
                                    vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
                                    vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
                                    vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                },
                                sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
}
|
|
4835
|
+
|
|
3792
4836
|
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
|
3793
4837
|
switch (op) {
|
|
3794
4838
|
case GGML_OP_GET_ROWS:
|
|
@@ -3933,10 +4977,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
3933
4977
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
3934
4978
|
|
|
3935
4979
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
|
3936
|
-
return ctx->device->pipeline_soft_max_f32;
|
|
4980
|
+
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
|
3937
4981
|
}
|
|
3938
4982
|
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
|
3939
|
-
return ctx->device->pipeline_soft_max_f32_f16;
|
|
4983
|
+
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
|
|
3940
4984
|
}
|
|
3941
4985
|
return nullptr;
|
|
3942
4986
|
case GGML_OP_ROPE:
|
|
@@ -3989,6 +5033,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
3989
5033
|
return ctx->device->pipeline_pool2d_f32;
|
|
3990
5034
|
}
|
|
3991
5035
|
return nullptr;
|
|
5036
|
+
case GGML_OP_RWKV_WKV6:
|
|
5037
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
5038
|
+
return ctx->device->pipeline_rwkv_wkv6_f32;
|
|
5039
|
+
}
|
|
5040
|
+
return nullptr;
|
|
3992
5041
|
case GGML_OP_LEAKY_RELU:
|
|
3993
5042
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
3994
5043
|
return ctx->device->pipeline_leaky_relu_f32;
|
|
@@ -4023,7 +5072,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
4023
5072
|
}
|
|
4024
5073
|
|
|
4025
5074
|
template<typename PC>
|
|
4026
|
-
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op,
|
|
5075
|
+
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
|
|
4027
5076
|
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
|
4028
5077
|
if (src1 != nullptr) {
|
|
4029
5078
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
@@ -4063,6 +5112,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
4063
5112
|
const uint64_t ned3 = dst->ne[3];
|
|
4064
5113
|
const uint64_t ned = ned0 * ned1;
|
|
4065
5114
|
|
|
5115
|
+
init_pushconst_fastdiv(pc);
|
|
5116
|
+
|
|
4066
5117
|
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
|
|
4067
5118
|
|
|
4068
5119
|
if (pipeline == nullptr) {
|
|
@@ -4389,6 +5440,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4389
5440
|
}, dryrun);
|
|
4390
5441
|
}
|
|
4391
5442
|
|
|
5443
|
+
static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
|
|
5444
|
+
const ggml_tensor * k = dst->src[0];
|
|
5445
|
+
const ggml_tensor * v = dst->src[1];
|
|
5446
|
+
const ggml_tensor * r = dst->src[2];
|
|
5447
|
+
const ggml_tensor * tf = dst->src[3];
|
|
5448
|
+
const ggml_tensor * td = dst->src[4];
|
|
5449
|
+
const ggml_tensor * state = dst->src[5];
|
|
5450
|
+
|
|
5451
|
+
GGML_ASSERT(!ggml_is_quantized(k->type));
|
|
5452
|
+
GGML_ASSERT(!ggml_is_quantized(v->type));
|
|
5453
|
+
GGML_ASSERT(!ggml_is_quantized(r->type));
|
|
5454
|
+
GGML_ASSERT(!ggml_is_quantized(tf->type));
|
|
5455
|
+
GGML_ASSERT(!ggml_is_quantized(td->type));
|
|
5456
|
+
GGML_ASSERT(!ggml_is_quantized(state->type));
|
|
5457
|
+
GGML_ASSERT(dst->buffer != nullptr);
|
|
5458
|
+
|
|
5459
|
+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
|
|
5460
|
+
GGML_ASSERT(pipeline != nullptr);
|
|
5461
|
+
|
|
5462
|
+
if (dryrun) {
|
|
5463
|
+
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
5464
|
+
return;
|
|
5465
|
+
}
|
|
5466
|
+
|
|
5467
|
+
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
|
5468
|
+
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
|
|
5469
|
+
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
|
|
5470
|
+
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
|
|
5471
|
+
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
|
|
5472
|
+
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
|
|
5473
|
+
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
|
|
5474
|
+
|
|
5475
|
+
ggml_vk_sync_buffers(subctx);
|
|
5476
|
+
|
|
5477
|
+
vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
|
|
5478
|
+
uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
|
|
5479
|
+
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
|
|
5480
|
+
|
|
5481
|
+
if (ctx->device->uma) {
|
|
5482
|
+
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
|
|
5483
|
+
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
|
|
5484
|
+
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
|
|
5485
|
+
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
|
|
5486
|
+
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
|
|
5487
|
+
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
|
|
5488
|
+
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
|
5489
|
+
|
|
5490
|
+
K_uma = d_K != nullptr;
|
|
5491
|
+
V_uma = d_V != nullptr;
|
|
5492
|
+
R_uma = d_R != nullptr;
|
|
5493
|
+
TF_uma = d_TF != nullptr;
|
|
5494
|
+
TD_uma = d_TD != nullptr;
|
|
5495
|
+
STATE_uma = d_State != nullptr;
|
|
5496
|
+
DST_uma = d_D != nullptr;
|
|
5497
|
+
}
|
|
5498
|
+
|
|
5499
|
+
if (!K_uma) {
|
|
5500
|
+
d_K = k_buf_ctx->dev_buffer;
|
|
5501
|
+
k_offset = vk_tensor_offset(k) + k->view_offs;
|
|
5502
|
+
}
|
|
5503
|
+
if (!V_uma) {
|
|
5504
|
+
d_V = v_buf_ctx->dev_buffer;
|
|
5505
|
+
v_offset = vk_tensor_offset(v) + v->view_offs;
|
|
5506
|
+
}
|
|
5507
|
+
if (!R_uma) {
|
|
5508
|
+
d_R = r_buf_ctx->dev_buffer;
|
|
5509
|
+
r_offset = vk_tensor_offset(r) + r->view_offs;
|
|
5510
|
+
}
|
|
5511
|
+
if (!TF_uma) {
|
|
5512
|
+
d_TF = tf_buf_ctx->dev_buffer;
|
|
5513
|
+
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
|
|
5514
|
+
}
|
|
5515
|
+
if (!TD_uma) {
|
|
5516
|
+
d_TD = td_buf_ctx->dev_buffer;
|
|
5517
|
+
td_offset = vk_tensor_offset(td) + td->view_offs;
|
|
5518
|
+
}
|
|
5519
|
+
if (!STATE_uma) {
|
|
5520
|
+
d_State = state_buf_ctx->dev_buffer;
|
|
5521
|
+
state_offset = vk_tensor_offset(state) + state->view_offs;
|
|
5522
|
+
}
|
|
5523
|
+
if (!DST_uma) {
|
|
5524
|
+
d_D = dst_buf_ctx->dev_buffer;
|
|
5525
|
+
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
|
5526
|
+
}
|
|
5527
|
+
|
|
5528
|
+
const uint64_t k_size = ggml_nbytes(k);
|
|
5529
|
+
const uint64_t v_size = ggml_nbytes(v);
|
|
5530
|
+
const uint64_t r_size = ggml_nbytes(r);
|
|
5531
|
+
const uint64_t tf_size = ggml_nbytes(tf);
|
|
5532
|
+
const uint64_t td_size = ggml_nbytes(td);
|
|
5533
|
+
const uint64_t state_size = ggml_nbytes(state);
|
|
5534
|
+
const uint64_t dst_size = ggml_nbytes(dst);
|
|
5535
|
+
|
|
5536
|
+
std::array<uint32_t, 3> elements = {
|
|
5537
|
+
(uint32_t)(pc.B * pc.H),
|
|
5538
|
+
1,
|
|
5539
|
+
1
|
|
5540
|
+
};
|
|
5541
|
+
|
|
5542
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
|
5543
|
+
vk_subbuffer{ d_K, k_offset, k_size },
|
|
5544
|
+
vk_subbuffer{ d_V, v_offset, v_size },
|
|
5545
|
+
vk_subbuffer{ d_R, r_offset, r_size },
|
|
5546
|
+
vk_subbuffer{ d_TF, tf_offset, tf_size },
|
|
5547
|
+
vk_subbuffer{ d_TD, td_offset, td_size },
|
|
5548
|
+
vk_subbuffer{ d_State, state_offset, state_size },
|
|
5549
|
+
vk_subbuffer{ d_D, dst_offset, dst_size }
|
|
5550
|
+
}, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
|
|
5551
|
+
}
|
|
5552
|
+
|
|
5553
|
+
static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
|
5554
|
+
const size_t seq_length = dst->src[0]->ne[3];
|
|
5555
|
+
const size_t n_embed = dst->ne[0];
|
|
5556
|
+
const size_t n_heads = dst->src[0]->ne[2];
|
|
5557
|
+
const size_t n_seqs = dst->src[5]->ne[1];
|
|
5558
|
+
|
|
5559
|
+
ggml_vk_op_f32_rwkv6(
|
|
5560
|
+
ctx, subctx, dst,
|
|
5561
|
+
{
|
|
5562
|
+
(uint32_t)n_seqs,
|
|
5563
|
+
(uint32_t)seq_length,
|
|
5564
|
+
(uint32_t)n_embed,
|
|
5565
|
+
(uint32_t)n_heads,
|
|
5566
|
+
},
|
|
5567
|
+
dryrun
|
|
5568
|
+
);
|
|
5569
|
+
}
|
|
5570
|
+
|
|
4392
5571
|
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
4393
5572
|
int * op_params = (int *)dst->op_params;
|
|
4394
5573
|
|
|
@@ -4432,7 +5611,8 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
4432
5611
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
4433
5612
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4434
5613
|
0,
|
|
4435
|
-
op_params[0], 0.0f
|
|
5614
|
+
op_params[0], 0.0f,
|
|
5615
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4436
5616
|
}, dryrun);
|
|
4437
5617
|
}
|
|
4438
5618
|
|
|
@@ -4446,6 +5626,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4446
5626
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4447
5627
|
0,
|
|
4448
5628
|
0.0f, 0.0f,
|
|
5629
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4449
5630
|
}, dryrun);
|
|
4450
5631
|
}
|
|
4451
5632
|
|
|
@@ -4459,6 +5640,7 @@ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4459
5640
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4460
5641
|
0,
|
|
4461
5642
|
0.0f, 0.0f,
|
|
5643
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4462
5644
|
}, dryrun);
|
|
4463
5645
|
}
|
|
4464
5646
|
|
|
@@ -4472,6 +5654,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4472
5654
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4473
5655
|
0,
|
|
4474
5656
|
0.0f, 0.0f,
|
|
5657
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4475
5658
|
}, dryrun);
|
|
4476
5659
|
}
|
|
4477
5660
|
|
|
@@ -4486,6 +5669,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
|
|
4486
5669
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4487
5670
|
0,
|
|
4488
5671
|
op_params[0], op_params[1],
|
|
5672
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4489
5673
|
}, dryrun);
|
|
4490
5674
|
}
|
|
4491
5675
|
|
|
@@ -4499,6 +5683,7 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4499
5683
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4500
5684
|
0,
|
|
4501
5685
|
0.0f, 0.0f,
|
|
5686
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4502
5687
|
}, dryrun);
|
|
4503
5688
|
}
|
|
4504
5689
|
|
|
@@ -4512,6 +5697,7 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
4512
5697
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4513
5698
|
0,
|
|
4514
5699
|
0.0f, 0.0f,
|
|
5700
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4515
5701
|
}, dryrun);
|
|
4516
5702
|
}
|
|
4517
5703
|
|
|
@@ -4526,6 +5712,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
|
|
4526
5712
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
4527
5713
|
d_offset,
|
|
4528
5714
|
0.0f, 0.0f,
|
|
5715
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
4529
5716
|
}, dryrun);
|
|
4530
5717
|
}
|
|
4531
5718
|
|
|
@@ -4582,6 +5769,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|
|
4582
5769
|
scale, max_bias,
|
|
4583
5770
|
m0, m1,
|
|
4584
5771
|
n_head_log2,
|
|
5772
|
+
nrows_x,
|
|
4585
5773
|
}, dryrun);
|
|
4586
5774
|
}
|
|
4587
5775
|
|
|
@@ -4878,19 +6066,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
4878
6066
|
for (size_t i = 0; i < x_ne; i++) {
|
|
4879
6067
|
if (std::is_same<float, X_TYPE>()) {
|
|
4880
6068
|
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
|
6069
|
+
// x[i] = 1.0f;
|
|
6070
|
+
// x[i] = i + 1;
|
|
6071
|
+
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
4881
6072
|
} else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
|
|
4882
6073
|
x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
|
|
6074
|
+
// x[i] = ggml_fp32_to_fp16(1.0f);
|
|
6075
|
+
// x[i] = ggml_fp32_to_fp16(i + 1);
|
|
6076
|
+
// x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
|
4883
6077
|
} else {
|
|
4884
6078
|
GGML_ABORT("fatal error");
|
|
4885
6079
|
}
|
|
4886
6080
|
}
|
|
4887
6081
|
for (size_t i = 0; i < y_ne; i++) {
|
|
4888
6082
|
if (std::is_same<float, Y_TYPE>()) {
|
|
4889
|
-
|
|
4890
|
-
y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
6083
|
+
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
|
6084
|
+
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
6085
|
+
// y[i] = i + 1;
|
|
4891
6086
|
} else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
|
|
4892
|
-
|
|
4893
|
-
y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
|
6087
|
+
y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
|
|
6088
|
+
// y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
|
|
6089
|
+
// y[i] = ggml_fp32_to_fp16(i + 1);
|
|
4894
6090
|
} else {
|
|
4895
6091
|
GGML_ABORT("fatal error");
|
|
4896
6092
|
}
|
|
@@ -4900,16 +6096,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
4900
6096
|
ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
|
|
4901
6097
|
|
|
4902
6098
|
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
6099
|
+
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
4903
6100
|
for (size_t i = 0; i < num_it; i++) {
|
|
4904
|
-
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
4905
6101
|
ggml_vk_matmul(
|
|
4906
6102
|
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
|
4907
6103
|
m, n, k,
|
|
4908
6104
|
k, k, m, k*m, k*n, m*n,
|
|
4909
6105
|
split_k, batch, batch, batch, 1, 1
|
|
4910
6106
|
);
|
|
4911
|
-
ggml_vk_ctx_end(subctx);
|
|
4912
6107
|
}
|
|
6108
|
+
ggml_vk_ctx_end(subctx);
|
|
4913
6109
|
|
|
4914
6110
|
auto begin = std::chrono::high_resolution_clock::now();
|
|
4915
6111
|
ggml_vk_submit(subctx, ctx->fence);
|
|
@@ -4974,7 +6170,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
4974
6170
|
double err = std::fabs(d[i] - d_chk[i]);
|
|
4975
6171
|
avg_err += err;
|
|
4976
6172
|
|
|
4977
|
-
if (err > 0.05f && first_err_n == -1) {
|
|
6173
|
+
if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
|
|
4978
6174
|
first_err_b = i / (m * n);
|
|
4979
6175
|
first_err_n = (i % (m * n)) / m;
|
|
4980
6176
|
first_err_m = (i % (m * n)) % m;
|
|
@@ -4987,12 +6183,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
4987
6183
|
|
|
4988
6184
|
std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
|
4989
6185
|
|
|
4990
|
-
if (avg_err > 0.1) {
|
|
6186
|
+
if (avg_err > 0.1 || std::isnan(avg_err)) {
|
|
4991
6187
|
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
|
4992
6188
|
std::cerr << "Actual result: " << std::endl << std::endl;
|
|
4993
6189
|
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
|
4994
|
-
std::cerr << std::endl;
|
|
4995
|
-
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
|
|
4996
6190
|
std::cerr << "Expected result: " << std::endl << std::endl;
|
|
4997
6191
|
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
|
4998
6192
|
|
|
@@ -5175,13 +6369,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
5175
6369
|
vk_pipeline p;
|
|
5176
6370
|
std::string shname;
|
|
5177
6371
|
if (shader_size == 0) {
|
|
5178
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
|
|
6372
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
|
|
5179
6373
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
|
5180
6374
|
} else if (shader_size == 1) {
|
|
5181
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
|
|
6375
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
|
|
5182
6376
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
|
5183
6377
|
} else if (shader_size == 2) {
|
|
5184
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
|
|
6378
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
|
|
5185
6379
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
|
5186
6380
|
} else {
|
|
5187
6381
|
GGML_ASSERT(0);
|
|
@@ -5191,13 +6385,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
5191
6385
|
|
|
5192
6386
|
if (k != kpad) {
|
|
5193
6387
|
if (shader_size == 0) {
|
|
5194
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
|
|
6388
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
|
|
5195
6389
|
shname = std::string(ggml_type_name(quant)) + "_S";
|
|
5196
6390
|
} else if (shader_size == 1) {
|
|
5197
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
|
|
6391
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
|
|
5198
6392
|
shname = std::string(ggml_type_name(quant)) + "_M";
|
|
5199
6393
|
} else if (shader_size == 2) {
|
|
5200
|
-
p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
|
|
6394
|
+
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
|
|
5201
6395
|
shname = std::string(ggml_type_name(quant)) + "_L";
|
|
5202
6396
|
} else {
|
|
5203
6397
|
GGML_ASSERT(0);
|
|
@@ -5247,16 +6441,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
5247
6441
|
ggml_vk_buffer_write(y_buf, 0, y, y_sz);
|
|
5248
6442
|
|
|
5249
6443
|
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
6444
|
+
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
5250
6445
|
for (size_t i = 0; i < num_it; i++) {
|
|
5251
|
-
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
5252
6446
|
ggml_vk_matmul(
|
|
5253
6447
|
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
|
5254
6448
|
m, n, k,
|
|
5255
6449
|
k, k, m, k*m, k*n, m*n,
|
|
5256
6450
|
split_k, batch, batch, batch, 1, 1
|
|
5257
6451
|
);
|
|
5258
|
-
ggml_vk_ctx_end(subctx);
|
|
5259
6452
|
}
|
|
6453
|
+
ggml_vk_ctx_end(subctx);
|
|
5260
6454
|
|
|
5261
6455
|
auto begin = std::chrono::high_resolution_clock::now();
|
|
5262
6456
|
|
|
@@ -5356,105 +6550,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
5356
6550
|
|
|
5357
6551
|
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
5358
6552
|
#if defined(GGML_VULKAN_RUN_TESTS)
|
|
5359
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
|
|
5360
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
|
|
5361
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
|
|
5362
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
|
|
5363
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
|
|
5364
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
|
|
5365
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
|
|
5366
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
|
|
5367
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
|
|
5368
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
|
|
5369
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
|
|
5370
|
-
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
|
|
5371
|
-
|
|
5372
|
-
ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
|
|
5373
|
-
|
|
5374
|
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
|
|
5375
|
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
|
|
5376
|
-
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
|
|
5377
|
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
|
|
5378
|
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
|
|
5379
|
-
// ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
|
|
5380
|
-
|
|
5381
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
|
|
5382
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
|
|
5383
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
|
|
5384
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
|
|
5385
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
|
|
5386
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
|
|
5387
|
-
|
|
5388
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
|
|
5389
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
|
|
5390
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
|
|
5391
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
|
|
5392
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
|
|
5393
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
|
|
5394
|
-
|
|
5395
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
|
|
5396
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
|
|
5397
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
|
|
5398
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
|
|
5399
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
|
|
5400
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
|
|
5401
|
-
|
|
5402
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
|
|
5403
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
|
|
5404
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
|
|
5405
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
|
|
5406
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
|
|
5407
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
|
|
5408
|
-
|
|
5409
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
|
|
5410
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
|
|
5411
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
|
|
5412
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
|
|
5413
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
|
|
5414
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
|
|
5415
|
-
|
|
5416
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
|
|
5417
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
|
|
5418
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
|
|
5419
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
|
|
5420
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
|
|
5421
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
|
|
5422
|
-
|
|
5423
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
|
|
5424
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
|
|
5425
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
|
|
5426
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
|
|
5427
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
|
|
5428
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
|
|
5429
|
-
|
|
5430
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
|
|
5431
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
|
|
5432
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
|
|
5433
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
|
|
5434
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
|
|
5435
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
|
|
5436
|
-
|
|
5437
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
|
|
5438
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
|
|
5439
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
|
|
5440
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
|
|
5441
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
|
|
5442
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
|
|
5443
|
-
|
|
5444
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
|
|
5445
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
|
|
5446
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
|
|
5447
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
|
|
5448
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
|
|
5449
|
-
// ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
|
|
5450
|
-
|
|
5451
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
|
|
5452
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
|
|
5453
|
-
ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
|
|
5454
|
-
|
|
5455
|
-
std::cerr << std::endl;
|
|
5456
|
-
|
|
5457
6553
|
const std::vector<size_t> vals {
|
|
6554
|
+
512, 512, 128,
|
|
6555
|
+
128, 512, 512,
|
|
6556
|
+
4096, 512, 4096,
|
|
6557
|
+
11008, 512, 4096,
|
|
6558
|
+
4096, 512, 11008,
|
|
6559
|
+
32000, 512, 4096,
|
|
5458
6560
|
8, 8, 8,
|
|
5459
6561
|
100, 46, 576,
|
|
5460
6562
|
623, 111, 128,
|
|
@@ -5467,25 +6569,52 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
5467
6569
|
49, 49, 128,
|
|
5468
6570
|
128, 49, 49,
|
|
5469
6571
|
4096, 49, 4096,
|
|
5470
|
-
11008, 49, 4096,
|
|
5471
|
-
4096, 49, 11008,
|
|
5472
|
-
32000, 49, 4096,
|
|
5473
|
-
512, 512, 128,
|
|
5474
|
-
128, 512, 512,
|
|
5475
|
-
4096, 512, 4096,
|
|
5476
|
-
11008, 512, 4096,
|
|
5477
|
-
4096, 512, 11008,
|
|
5478
|
-
32000, 512, 4096,
|
|
5479
6572
|
};
|
|
5480
|
-
const size_t num_it =
|
|
6573
|
+
const size_t num_it = 100;
|
|
6574
|
+
|
|
5481
6575
|
for (size_t i = 0; i < vals.size(); i += 3) {
|
|
5482
6576
|
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
|
5483
6577
|
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
|
|
5484
6578
|
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
|
|
5485
|
-
|
|
5486
|
-
|
|
5487
|
-
|
|
5488
|
-
|
|
6579
|
+
std::cerr << '\n';
|
|
6580
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
|
|
6581
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
|
|
6582
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
|
|
6583
|
+
std::cerr << '\n';
|
|
6584
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
|
|
6585
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
|
|
6586
|
+
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
|
|
6587
|
+
std::cerr << '\n' << std::endl;
|
|
6588
|
+
|
|
6589
|
+
if (vals[i + 2] % 32 == 0) {
|
|
6590
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
|
|
6591
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
|
|
6592
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
|
|
6593
|
+
std::cerr << '\n';
|
|
6594
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
|
|
6595
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
|
|
6596
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
|
|
6597
|
+
std::cerr << '\n';
|
|
6598
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
|
|
6599
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
|
|
6600
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
|
|
6601
|
+
std::cerr << '\n' << std::endl;
|
|
6602
|
+
}
|
|
6603
|
+
|
|
6604
|
+
if (vals[i + 2] % 256 == 0) {
|
|
6605
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
|
|
6606
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
|
|
6607
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
|
|
6608
|
+
std::cerr << '\n';
|
|
6609
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
|
|
6610
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
|
|
6611
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
|
|
6612
|
+
std::cerr << '\n';
|
|
6613
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
|
|
6614
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
|
|
6615
|
+
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
|
|
6616
|
+
std::cerr << '\n' << std::endl;
|
|
6617
|
+
}
|
|
5489
6618
|
}
|
|
5490
6619
|
|
|
5491
6620
|
GGML_ABORT("fatal error");
|
|
@@ -5532,6 +6661,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
5532
6661
|
const ggml_tensor * src0 = node->src[0];
|
|
5533
6662
|
const ggml_tensor * src1 = node->src[1];
|
|
5534
6663
|
const ggml_tensor * src2 = node->src[2];
|
|
6664
|
+
const ggml_tensor * src3 = node->src[3];
|
|
5535
6665
|
|
|
5536
6666
|
switch (node->op) {
|
|
5537
6667
|
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
|
|
@@ -5583,7 +6713,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
5583
6713
|
case GGML_OP_IM2COL:
|
|
5584
6714
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
5585
6715
|
case GGML_OP_POOL_2D:
|
|
6716
|
+
case GGML_OP_RWKV_WKV6:
|
|
5586
6717
|
case GGML_OP_LEAKY_RELU:
|
|
6718
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
5587
6719
|
break;
|
|
5588
6720
|
default:
|
|
5589
6721
|
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
|
@@ -5601,6 +6733,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
5601
6733
|
} else {
|
|
5602
6734
|
compute_ctx = ctx->compute_ctx.lock();
|
|
5603
6735
|
}
|
|
6736
|
+
} else {
|
|
6737
|
+
switch (node->op) {
|
|
6738
|
+
case GGML_OP_REPEAT:
|
|
6739
|
+
case GGML_OP_ACC:
|
|
6740
|
+
case GGML_OP_GET_ROWS:
|
|
6741
|
+
case GGML_OP_ADD:
|
|
6742
|
+
case GGML_OP_MUL:
|
|
6743
|
+
case GGML_OP_DIV:
|
|
6744
|
+
case GGML_OP_CONCAT:
|
|
6745
|
+
case GGML_OP_UPSCALE:
|
|
6746
|
+
case GGML_OP_SCALE:
|
|
6747
|
+
case GGML_OP_SQR:
|
|
6748
|
+
case GGML_OP_SIN:
|
|
6749
|
+
case GGML_OP_COS:
|
|
6750
|
+
case GGML_OP_CLAMP:
|
|
6751
|
+
case GGML_OP_PAD:
|
|
6752
|
+
case GGML_OP_CPY:
|
|
6753
|
+
case GGML_OP_CONT:
|
|
6754
|
+
case GGML_OP_DUP:
|
|
6755
|
+
case GGML_OP_NORM:
|
|
6756
|
+
case GGML_OP_GROUP_NORM:
|
|
6757
|
+
case GGML_OP_RMS_NORM:
|
|
6758
|
+
case GGML_OP_UNARY:
|
|
6759
|
+
case GGML_OP_DIAG_MASK_INF:
|
|
6760
|
+
case GGML_OP_SOFT_MAX:
|
|
6761
|
+
case GGML_OP_ROPE:
|
|
6762
|
+
case GGML_OP_ARGSORT:
|
|
6763
|
+
case GGML_OP_SUM_ROWS:
|
|
6764
|
+
case GGML_OP_IM2COL:
|
|
6765
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
6766
|
+
case GGML_OP_POOL_2D:
|
|
6767
|
+
case GGML_OP_LEAKY_RELU:
|
|
6768
|
+
{
|
|
6769
|
+
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
|
6770
|
+
// do the only thing needed for the dryrun.
|
|
6771
|
+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
|
|
6772
|
+
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
6773
|
+
return false;
|
|
6774
|
+
}
|
|
6775
|
+
default:
|
|
6776
|
+
break;
|
|
6777
|
+
}
|
|
5604
6778
|
}
|
|
5605
6779
|
|
|
5606
6780
|
switch (node->op) {
|
|
@@ -5734,6 +6908,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
5734
6908
|
case GGML_OP_MUL_MAT_ID:
|
|
5735
6909
|
ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
5736
6910
|
|
|
6911
|
+
break;
|
|
6912
|
+
|
|
6913
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
6914
|
+
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
|
|
6915
|
+
|
|
6916
|
+
break;
|
|
6917
|
+
|
|
6918
|
+
case GGML_OP_RWKV_WKV6:
|
|
6919
|
+
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
|
|
6920
|
+
|
|
5737
6921
|
break;
|
|
5738
6922
|
default:
|
|
5739
6923
|
return false;
|
|
@@ -5814,6 +6998,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
5814
6998
|
case GGML_OP_IM2COL:
|
|
5815
6999
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
5816
7000
|
case GGML_OP_POOL_2D:
|
|
7001
|
+
case GGML_OP_RWKV_WKV6:
|
|
5817
7002
|
case GGML_OP_LEAKY_RELU:
|
|
5818
7003
|
case GGML_OP_REPEAT:
|
|
5819
7004
|
buf = tensor->buffer;
|
|
@@ -5834,6 +7019,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
5834
7019
|
break;
|
|
5835
7020
|
case GGML_OP_MUL_MAT:
|
|
5836
7021
|
case GGML_OP_MUL_MAT_ID:
|
|
7022
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
5837
7023
|
buf = tensor->buffer;
|
|
5838
7024
|
|
|
5839
7025
|
break;
|
|
@@ -6330,16 +7516,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
6330
7516
|
bool first_node_in_batch = true; // true if next node will be first node in a batch
|
|
6331
7517
|
int submit_node_idx = 0; // index to first node in a batch
|
|
6332
7518
|
|
|
6333
|
-
//
|
|
6334
|
-
|
|
7519
|
+
// Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
|
|
7520
|
+
// Start with a smaller count to get work submitted right away, and increase it after each submit.
|
|
7521
|
+
int nodes_per_submit = 20;
|
|
6335
7522
|
int submitted_nodes = 0;
|
|
7523
|
+
int submit_count = 0;
|
|
6336
7524
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
6337
7525
|
if (first_node_in_batch) {
|
|
6338
7526
|
submit_node_idx = i;
|
|
6339
7527
|
}
|
|
6340
7528
|
|
|
6341
|
-
bool submit = (submitted_nodes >=
|
|
6342
|
-
|
|
7529
|
+
bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
|
|
6343
7530
|
|
|
6344
7531
|
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
|
|
6345
7532
|
|
|
@@ -6356,6 +7543,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
6356
7543
|
if (submit) {
|
|
6357
7544
|
first_node_in_batch = true;
|
|
6358
7545
|
submitted_nodes = 0;
|
|
7546
|
+
switch (submit_count) {
|
|
7547
|
+
case 0:
|
|
7548
|
+
nodes_per_submit = 50;
|
|
7549
|
+
break;
|
|
7550
|
+
default:
|
|
7551
|
+
nodes_per_submit = 100;
|
|
7552
|
+
break;
|
|
7553
|
+
}
|
|
7554
|
+
submit_count++;
|
|
6359
7555
|
}
|
|
6360
7556
|
}
|
|
6361
7557
|
|
|
@@ -6512,6 +7708,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
6512
7708
|
case GGML_OP_MUL_MAT:
|
|
6513
7709
|
case GGML_OP_MUL_MAT_ID:
|
|
6514
7710
|
{
|
|
7711
|
+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
7712
|
+
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
7713
|
+
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
|
|
7714
|
+
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
|
7715
|
+
return false;
|
|
7716
|
+
}
|
|
6515
7717
|
switch (op->src[0]->type) {
|
|
6516
7718
|
case GGML_TYPE_F32:
|
|
6517
7719
|
case GGML_TYPE_F16:
|
|
@@ -6549,6 +7751,57 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
6549
7751
|
|
|
6550
7752
|
return true;
|
|
6551
7753
|
} break;
|
|
7754
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
|
7755
|
+
{
|
|
7756
|
+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
7757
|
+
if (!ggml_vk_get_device(ctx->device)->coopmat2) {
|
|
7758
|
+
return false;
|
|
7759
|
+
}
|
|
7760
|
+
switch (op->src[0]->ne[0]) {
|
|
7761
|
+
case 64:
|
|
7762
|
+
case 80:
|
|
7763
|
+
case 96:
|
|
7764
|
+
case 112:
|
|
7765
|
+
case 128:
|
|
7766
|
+
case 256:
|
|
7767
|
+
break;
|
|
7768
|
+
default:
|
|
7769
|
+
return false;
|
|
7770
|
+
}
|
|
7771
|
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
7772
|
+
return false;
|
|
7773
|
+
}
|
|
7774
|
+
if (op->type != GGML_TYPE_F32) {
|
|
7775
|
+
return false;
|
|
7776
|
+
}
|
|
7777
|
+
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
|
|
7778
|
+
return false;
|
|
7779
|
+
}
|
|
7780
|
+
// It's straightforward to support different K/V dequant, but would
|
|
7781
|
+
// significantly increase the number of pipelines
|
|
7782
|
+
if (op->src[1]->type != op->src[2]->type) {
|
|
7783
|
+
return false;
|
|
7784
|
+
}
|
|
7785
|
+
switch (op->src[1]->type) {
|
|
7786
|
+
case GGML_TYPE_F16:
|
|
7787
|
+
case GGML_TYPE_Q4_0:
|
|
7788
|
+
case GGML_TYPE_Q4_1:
|
|
7789
|
+
case GGML_TYPE_Q5_0:
|
|
7790
|
+
case GGML_TYPE_Q5_1:
|
|
7791
|
+
case GGML_TYPE_Q8_0:
|
|
7792
|
+
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
7793
|
+
//case GGML_TYPE_Q2_K:
|
|
7794
|
+
//case GGML_TYPE_Q3_K:
|
|
7795
|
+
//case GGML_TYPE_Q4_K:
|
|
7796
|
+
//case GGML_TYPE_Q5_K:
|
|
7797
|
+
//case GGML_TYPE_Q6_K:
|
|
7798
|
+
case GGML_TYPE_IQ4_NL:
|
|
7799
|
+
break;
|
|
7800
|
+
default:
|
|
7801
|
+
return false;
|
|
7802
|
+
}
|
|
7803
|
+
return true;
|
|
7804
|
+
}
|
|
6552
7805
|
case GGML_OP_GET_ROWS:
|
|
6553
7806
|
{
|
|
6554
7807
|
switch (op->src[0]->type) {
|
|
@@ -6585,7 +7838,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
6585
7838
|
case GGML_OP_REPEAT:
|
|
6586
7839
|
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
|
|
6587
7840
|
case GGML_OP_ROPE:
|
|
6588
|
-
|
|
7841
|
+
{
|
|
7842
|
+
const int mode = ((const int32_t *) op->op_params)[2];
|
|
7843
|
+
if (mode & GGML_ROPE_TYPE_MROPE) {
|
|
7844
|
+
return false;
|
|
7845
|
+
}
|
|
7846
|
+
if (mode & GGML_ROPE_TYPE_VISION) {
|
|
7847
|
+
return false;
|
|
7848
|
+
}
|
|
7849
|
+
return ggml_is_contiguous(op->src[0]);
|
|
7850
|
+
}
|
|
6589
7851
|
case GGML_OP_NONE:
|
|
6590
7852
|
case GGML_OP_RESHAPE:
|
|
6591
7853
|
case GGML_OP_VIEW:
|
|
@@ -6613,6 +7875,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
6613
7875
|
case GGML_OP_IM2COL:
|
|
6614
7876
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
6615
7877
|
case GGML_OP_POOL_2D:
|
|
7878
|
+
case GGML_OP_RWKV_WKV6:
|
|
6616
7879
|
case GGML_OP_LEAKY_RELU:
|
|
6617
7880
|
return true;
|
|
6618
7881
|
default:
|
|
@@ -6709,8 +7972,9 @@ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
|
|
|
6709
7972
|
|
|
6710
7973
|
ggml_backend_reg_t ggml_backend_vk_reg() {
|
|
6711
7974
|
static ggml_backend_reg reg = {
|
|
6712
|
-
/* .
|
|
6713
|
-
/* .
|
|
7975
|
+
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
|
7976
|
+
/* .iface = */ ggml_backend_vk_reg_i,
|
|
7977
|
+
/* .context = */ nullptr,
|
|
6714
7978
|
};
|
|
6715
7979
|
|
|
6716
7980
|
return ®
|
|
@@ -6862,6 +8126,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
6862
8126
|
ggml_tensor * src0 = tensor->src[0];
|
|
6863
8127
|
ggml_tensor * src1 = tensor->src[1];
|
|
6864
8128
|
ggml_tensor * src2 = tensor->src[2];
|
|
8129
|
+
ggml_tensor * src3 = tensor->src[3];
|
|
6865
8130
|
|
|
6866
8131
|
struct ggml_init_params iparams = {
|
|
6867
8132
|
/*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
|
|
@@ -6874,15 +8139,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
6874
8139
|
struct ggml_tensor * src0_clone = nullptr;
|
|
6875
8140
|
struct ggml_tensor * src1_clone = nullptr;
|
|
6876
8141
|
struct ggml_tensor * src2_clone = nullptr;
|
|
8142
|
+
struct ggml_tensor * src3_clone = nullptr;
|
|
6877
8143
|
struct ggml_tensor * tensor_clone = nullptr;
|
|
6878
8144
|
|
|
6879
8145
|
size_t src0_size;
|
|
6880
8146
|
size_t src1_size;
|
|
6881
8147
|
size_t src2_size;
|
|
8148
|
+
size_t src3_size;
|
|
6882
8149
|
|
|
6883
8150
|
void * src0_buffer = nullptr;
|
|
6884
8151
|
void * src1_buffer = nullptr;
|
|
6885
8152
|
void * src2_buffer = nullptr;
|
|
8153
|
+
void * src3_buffer = nullptr;
|
|
6886
8154
|
|
|
6887
8155
|
if (src0 != nullptr) {
|
|
6888
8156
|
src0_clone = ggml_dup_tensor(ggml_ctx, src0);
|
|
@@ -7010,8 +8278,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
7010
8278
|
ggml_vk_print_tensor(src2, "src2");
|
|
7011
8279
|
}
|
|
7012
8280
|
}
|
|
8281
|
+
if (src3 != nullptr) {
|
|
8282
|
+
src3_clone = ggml_dup_tensor(ggml_ctx, src3);
|
|
7013
8283
|
|
|
7014
|
-
|
|
8284
|
+
src3_size = ggml_nbytes(src3);
|
|
8285
|
+
|
|
8286
|
+
src3_buffer = malloc(src3_size);
|
|
8287
|
+
src3_clone->data = src3_buffer;
|
|
8288
|
+
if (ggml_backend_buffer_is_host(src3->buffer)) {
|
|
8289
|
+
memcpy(src3_clone->data, src3->data, src3_size);
|
|
8290
|
+
memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
|
8291
|
+
} else if (ggml_backend_buffer_is_vk(src3->buffer)) {
|
|
8292
|
+
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
|
|
8293
|
+
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
|
|
8294
|
+
uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
|
|
8295
|
+
if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
|
|
8296
|
+
for (int i3 = 0; i3 < src3->ne[3]; i3++) {
|
|
8297
|
+
for (int i2 = 0; i2 < src3->ne[2]; i2++) {
|
|
8298
|
+
const int idx = i3*src3->ne[2] + i2;
|
|
8299
|
+
ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
|
|
8300
|
+
}
|
|
8301
|
+
}
|
|
8302
|
+
|
|
8303
|
+
src3_clone->nb[0] = src3->nb[0];
|
|
8304
|
+
src3_clone->nb[1] = src3->nb[1];
|
|
8305
|
+
for (int i = 2; i < GGML_MAX_DIMS; i++) {
|
|
8306
|
+
src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
|
|
8307
|
+
}
|
|
8308
|
+
} else {
|
|
8309
|
+
if (offset + src3_size >= buffer_gpu->size) {
|
|
8310
|
+
src3_size = buffer_gpu->size - offset;
|
|
8311
|
+
}
|
|
8312
|
+
ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
|
|
8313
|
+
memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
|
8314
|
+
}
|
|
8315
|
+
} else {
|
|
8316
|
+
GGML_ABORT("fatal error");
|
|
8317
|
+
}
|
|
8318
|
+
|
|
8319
|
+
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
|
|
8320
|
+
ggml_vk_print_tensor(src3, "src3");
|
|
8321
|
+
}
|
|
8322
|
+
}
|
|
8323
|
+
|
|
8324
|
+
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
|
8325
|
+
const float *params = (const float *)tensor->op_params;
|
|
8326
|
+
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
|
|
8327
|
+
} else if (tensor->op == GGML_OP_MUL_MAT) {
|
|
7015
8328
|
tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
|
|
7016
8329
|
} else if (tensor->op == GGML_OP_MUL_MAT_ID) {
|
|
7017
8330
|
tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
|
|
@@ -7127,7 +8440,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
7127
8440
|
const int32_t max_period = tensor->op_params[1];
|
|
7128
8441
|
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
|
|
7129
8442
|
} else if (tensor->op == GGML_OP_POOL_2D) {
|
|
7130
|
-
enum ggml_op_pool op = static_cast<ggml_op_pool>(
|
|
8443
|
+
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
|
|
7131
8444
|
const int32_t k0 = tensor->op_params[1];
|
|
7132
8445
|
const int32_t k1 = tensor->op_params[2];
|
|
7133
8446
|
const int32_t s0 = tensor->op_params[3];
|
|
@@ -7139,7 +8452,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
7139
8452
|
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
|
|
7140
8453
|
const float * op_params = (const float *)tensor->op_params;
|
|
7141
8454
|
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
|
|
7142
|
-
} else {
|
|
8455
|
+
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
|
|
8456
|
+
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
|
|
8457
|
+
tensor->src[4], tensor->src[5]);
|
|
8458
|
+
}
|
|
8459
|
+
else {
|
|
7143
8460
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
|
7144
8461
|
GGML_ABORT("fatal error");
|
|
7145
8462
|
}
|
|
@@ -7336,3 +8653,5 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|
|
7336
8653
|
VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
|
|
7337
8654
|
}
|
|
7338
8655
|
#endif
|
|
8656
|
+
|
|
8657
|
+
GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)
|