@fugood/llama.node 0.3.13 → 0.3.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +1 -1
- package/package.json +1 -1
- package/src/LlamaContext.cpp +98 -76
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +89 -10
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/CMakeLists.txt +9 -1
- package/src/llama.cpp/cmake/common.cmake +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +3 -3
- package/src/llama.cpp/common/arg.cpp +132 -13
- package/src/llama.cpp/common/chat.cpp +960 -266
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +33 -174
- package/src/llama.cpp/common/common.h +27 -67
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +45 -7
- package/src/llama.cpp/common/speculative.cpp +10 -9
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +45 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
- package/src/llama.cpp/examples/infill/infill.cpp +2 -2
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
- package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
- package/src/llama.cpp/examples/main/main.cpp +79 -34
- package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
- package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +196 -108
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
- package/src/llama.cpp/examples/server/server.cpp +113 -101
- package/src/llama.cpp/examples/server/utils.hpp +94 -105
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
- package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +263 -151
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
- package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
- package/src/llama.cpp/ggml/include/ggml.h +29 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
- package/src/llama.cpp/ggml/src/ggml.c +93 -5
- package/src/llama.cpp/include/llama.h +105 -27
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/CMakeLists.txt +5 -2
- package/src/llama.cpp/src/llama-adapter.cpp +19 -20
- package/src/llama.cpp/src/llama-adapter.h +11 -9
- package/src/llama.cpp/src/llama-arch.cpp +123 -16
- package/src/llama.cpp/src/llama-arch.h +19 -0
- package/src/llama.cpp/src/llama-batch.h +2 -2
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-context.cpp +2253 -1222
- package/src/llama.cpp/src/llama-context.h +214 -77
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +182 -182
- package/src/llama.cpp/src/llama-grammar.h +12 -3
- package/src/llama.cpp/src/llama-graph.cpp +1662 -0
- package/src/llama.cpp/src/llama-graph.h +574 -0
- package/src/llama.cpp/src/llama-hparams.cpp +8 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -0
- package/src/llama.cpp/src/llama-io.cpp +15 -0
- package/src/llama.cpp/src/llama-io.h +35 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
- package/src/llama.cpp/src/llama-kv-cache.h +178 -109
- package/src/llama.cpp/src/llama-memory.cpp +1 -0
- package/src/llama.cpp/src/llama-memory.h +21 -0
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-model.cpp +8230 -122
- package/src/llama.cpp/src/llama-model.h +34 -1
- package/src/llama.cpp/src/llama-quant.cpp +10 -1
- package/src/llama.cpp/src/llama-sampling.cpp +43 -10
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +51 -9837
- package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +593 -395
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -55
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
- /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
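The new ROUNDUP_POW2 macro rounds M up to the next multiple of N with a bit mask, which is only valid when N is a power of two. A quick worked check (illustration, not part of the diff):

    // (37 + 16 - 1) & ~(16 - 1)  ->  52 & ~15  ->  48
    static_assert(ROUNDUP_POW2(37, 16) == 48, "37 rounds up to the next multiple of 16");
    static_assert(ROUNDUP_POW2(48, 16) == 48, "exact multiples are unchanged");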
@@ -149,6 +150,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
 
+enum vk_device_architecture {
+    OTHER,
+    AMD_GCN,
+    AMD_RDNA1,
+    AMD_RDNA2,
+    AMD_RDNA3,
+};
+
+static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
+    vk::PhysicalDeviceProperties props = device.getProperties();
+
+    if (props.vendorID == VK_VENDOR_ID_AMD) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool amd_shader_core_properties = false;
+        bool integer_dot_product = false;
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
+                amd_shader_core_properties = true;
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
+                integer_dot_product = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &shader_core_props_amd;
+        shader_core_props_amd.pNext = &integer_dot_props;
+        integer_dot_props.pNext = &subgroup_size_control_props;
+
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
+            return vk_device_architecture::AMD_GCN;
+        }
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
+            // RDNA
+            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
+                return vk_device_architecture::AMD_RDNA1;
+            }
+            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
+                return vk_device_architecture::AMD_RDNA3;
+            }
+            return vk_device_architecture::AMD_RDNA2;
+        }
+    }
+    return vk_device_architecture::OTHER;
+}
+
 struct vk_device_struct {
     std::mutex mutex;
 
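The detection above distinguishes AMD generations purely from Vulkan properties: a fixed 64-wide subgroup range indicates GCN, a 32-to-64 range indicates RDNA; within RDNA, 20 wavefronts per SIMD singles out RDNA1, and accelerated packed 4x8-bit mixed-signedness integer dot products single out RDNA3 over RDNA2. A minimal standalone probe (sketch, not part of the package; assumes Vulkan-Hpp with a loader, and that get_device_architecture from the hunk above is in scope):

    #include <vulkan/vulkan.hpp>
    #include <cstdio>

    int main() {
        vk::ApplicationInfo app_info("arch-probe", 1, nullptr, 0, VK_API_VERSION_1_2);
        vk::Instance instance = vk::createInstance(vk::InstanceCreateInfo({}, &app_info));
        for (const vk::PhysicalDevice& dev : instance.enumeratePhysicalDevices()) {
            // Enum order from the diff: OTHER(0), AMD_GCN(1), AMD_RDNA1(2), AMD_RDNA2(3), AMD_RDNA3(4)
            int arch = (int) get_device_architecture(dev);
            printf("%s -> architecture enum %d\n", dev.getProperties().deviceName.data(), arch);
        }
        instance.destroy();
        return 0;
    }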
@@ -161,6 +222,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
     bool single_queue;
@@ -241,15 +303,20 @@
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
+    vk_pipeline pipeline_sigmoid_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -262,6 +329,7 @@
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -364,6 +432,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -376,6 +445,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -504,6 +574,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };
 
 struct vk_op_soft_max_push_constants {
@@ -560,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
     uint32_t H;
 };
 
+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1440,6 +1518,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     return supported;
 }
 
+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
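Note the lookup order in get_subgroup_size: an exact name match wins; otherwise the keys are sorted by descending length so the most specific substring applies. Traced against rdna1_pipelines (illustration, not part of the diff):

    // get_subgroup_size("mul_mat_vec_f16_f16_f32_1",  AMD_RDNA1) == 32   longest match: "mul_mat_vec_f16"
    // get_subgroup_size("mul_mat_vec_q4_k_f32_f32_1", AMD_RDNA1) == 64   falls through to "mul_mat_vec"
    // get_subgroup_size("soft_max_f32_wg512",         AMD_RDNA1) == 64   substring match on "soft_max"
    // get_subgroup_size("cpy_f32_f32",                AMD_RDNA1) == 32   no key matches; default_subgroup_size
    // get_subgroup_size(any_name, OTHER)                         == 0    no config entry; keep the driver default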
@@ -1461,36 +1606,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
         s_wg_denoms = { 64, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256,
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = {
+        s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128,
-        m_warptile_mmq_k = { 256,
-        s_warptile_mmq_k = { 256, 32, 128,
-        l_mmq_wg_denoms_k = {
-        m_mmq_wg_denoms_k = {
-        s_mmq_wg_denoms_k = { 32,
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128,
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256,
-        l_mmqid_wg_denoms = { 128,
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = {
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align = 64;
@@ -1566,6 +1711,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
         uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
         if (!pipeline) {
             pipeline = std::make_shared<vk_pipeline_struct>();
             pipeline->name = name;
@@ -1987,6 +2136,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         }
     } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
         rm_stdq = 2;
+    uint32_t rm_iq = 2 * rm_kq;
 
     for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2001,15 +2151,15 @@
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data,
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2023,15 +2173,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data,
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2046,15 +2196,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2121,6 +2271,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2180,9 +2332,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
@@ -2190,6 +2344,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -2229,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
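Several pipelines registered in these hunks (rms_norm_back_f32, silu_back_f32, soft_max_back_f32, plus the new is_back rope flag) are backward-pass kernels. As a reference point for silu_back, the gradient is standard calculus rather than anything stated in this diff: with sigma(x) = 1/(1 + exp(-x)) and silu(x) = x * sigma(x),

    // d/dx silu(x)     = sigma(x) * (1 + x * (1 - sigma(x)))
    // silu_back(dy, x) = dy * sigma(x) * (1 + x * (1 - sigma(x)))
    // (sketch of the chain rule such a kernel applies; the shader source is not shown here)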
@@ -2237,7 +2394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->need_compiles = false;
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -2266,6 +2423,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->physical_device = physical_devices[dev_num];
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
+        device->architecture = get_device_architecture(device->physical_device);
+
         const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
         device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
 
@@ -2278,7 +2437,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool coopmat2_support = false;
         device->coopmat_support = false;
 
-        // Check if maintenance4 is supported
        for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
@@ -2366,13 +2524,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
             device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
-
-        } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+        } else {
             // Limit batching of allocations to 1GB by default to avoid fragmentation issues
             device->suballocation_block_size = 1024*1024*1024;
-#endif
-        } else {
-            device->suballocation_block_size = device->max_memory_allocation_size;
         }
         device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
@@ -2391,7 +2545,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
+        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
            device->coopmat_support = false;
         }
 
@@ -2769,7 +2923,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     subgroup_props.pNext = &driver_props;
     physical_device.getProperties2(&props2);
 
-    const size_t subgroup_size = subgroup_props.subgroupSize;
+    vk_device_architecture arch = get_device_architecture(physical_device);
+    uint32_t default_subgroup_size = get_subgroup_size("", arch);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+
     const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 
     bool fp16_storage = false;
@@ -2795,7 +2952,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
+    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
+
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
         coopmat_support = false;
     }
 
@@ -3840,10 +3999,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
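The comments in this hunk state the selection rule; a concrete trace (illustration, using the non-quant coopmat2 tiles set earlier, m_wg_denoms = {128, 128, 1} and s_wg_denoms = { 64, 64, 1 }):

    // crossover_large = m_wg_denoms[1] = 128, crossover_medium = s_wg_denoms[1] = 64
    // n = 512 -> n > 128          -> large pipeline  (mmp->l, or mmp->a_l when aligned)
    // n = 100 -> 64 < n <= 128    -> medium pipeline (mmp->m)
    // n = 16  -> n <= 64          -> small pipeline  (mmp->s)
    // If neither the medium nor the small shader exists for src0_type, the
    // large pipeline is chosen regardless of n.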
@@ -3868,18 +4031,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
@@ -3888,13 +4052,17 @@ static void ggml_vk_matmul(
|
|
|
3888
4052
|
}
|
|
3889
4053
|
|
|
3890
4054
|
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
|
|
3891
|
-
VK_LOG_DEBUG("
|
|
4055
|
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
|
3892
4056
|
|
|
3893
4057
|
if (ctx->device->coopmat2) {
|
|
3894
|
-
|
|
4058
|
+
// Use large shader when the N dimension is greater than the medium shader's tile size
|
|
4059
|
+
uint32_t crossover_large = mmp->m->wg_denoms[1];
|
|
4060
|
+
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
|
|
3895
4061
|
return aligned ? mmp->a_l : mmp->l;
|
|
3896
4062
|
}
|
|
3897
|
-
|
|
4063
|
+
// Use medium shader when the N dimension is greater than the small shader's tile size
|
|
4064
|
+
uint32_t crossover_medium = mmp->s->wg_denoms[1];
|
|
4065
|
+
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
|
|
3898
4066
|
return aligned ? mmp->a_m : mmp->m;
|
|
3899
4067
|
}
|
|
3900
4068
|
return aligned ? mmp->a_s : mmp->s;
|
|
@@ -3919,14 +4087,15 @@ static void ggml_vk_matmul_id(
|
|
|
3919
4087
|
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
|
|
3920
4088
|
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
|
|
3921
4089
|
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
|
|
3922
|
-
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11
|
|
4090
|
+
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
|
|
4091
|
+
uint32_t padded_n) {
|
|
3923
4092
|
VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
|
|
3924
4093
|
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
|
|
3925
4094
|
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
|
|
3926
4095
|
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
|
|
3927
4096
|
ggml_vk_sync_buffers(subctx);
|
|
3928
4097
|
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
|
|
3929
|
-
nei0, nei1, nbi1, ne11 };
|
|
4098
|
+
nei0, nei1, nbi1, ne11, padded_n };
|
|
3930
4099
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
|
|
3931
4100
|
}
|
|
3932
4101
|
|
|
@@ -4088,15 +4257,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4088
4257
|
// Not implemented
|
|
4089
4258
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
4090
4259
|
|
|
4091
|
-
const int x_ne = ne01 * ne00;
|
|
4092
|
-
const int y_ne = ne11 * ne10;
|
|
4093
|
-
const int d_ne = ne11 * ne01;
|
|
4094
|
-
|
|
4095
4260
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
|
4096
4261
|
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
|
4097
4262
|
|
|
4098
4263
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
|
4099
4264
|
|
|
4265
|
+
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
|
4266
|
+
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
|
4267
|
+
const int x_ne = ne01 * ne00;
|
|
4268
|
+
const int y_ne = padded_n * ne10;
|
|
4269
|
+
const int d_ne = ne11 * ne01;
|
|
4270
|
+
|
|
4100
4271
|
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
|
4101
4272
|
|
|
4102
4273
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
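`padded_n` above rounds the N dimension of the dequantized Y matrix up to the workgroup tile height so the shader can skip bounds checks. A sketch of the padding arithmetic that `ROUNDUP_POW2` presumably performs (the macro's definition is outside this diff):

```cpp
#include <cstdint>
#include <cstdio>

// Round x up to a multiple of pot; valid only when pot is a power of two.
static uint32_t roundup_pow2(uint32_t x, uint32_t pot) {
    return (x + pot - 1) & ~(pot - 1);
}

int main() {
    const uint32_t tile_n = 64;  // stand-in for pipeline->wg_denoms[1]
    const uint32_t tests[] = { 1, 64, 65, 200 };
    for (uint32_t n : tests) {
        std::printf("n=%u -> padded_n=%u\n", n, roundup_pow2(n, tile_n));
    }
}
```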
@@ -4183,7 +4354,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_sz *
+        GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -4219,7 +4390,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
         ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-        split_k, ne12*ne13, ne02, ne12, r2, r3
+        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     ); // NOLINT
 }

@@ -4670,15 +4841,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4760,7 +4933,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_sz *
+        GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -4797,7 +4970,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
         ne01, ne21, ne10, ne10, ne10, ne01,
         stride_batch_x, stride_batch_y, ne20*ne21,
-        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
     ); // NOLINT
 }

@@ -5283,6 +5456,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SILU_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_silu_back_f32;
+        }
+        return nullptr;
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_norm_f32;
@@ -5298,6 +5476,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
        return nullptr;
+    case GGML_OP_RMS_NORM_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rms_norm_back_f32;
+        }
+        return nullptr;
+    case GGML_OP_L2_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_l2_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
         case GGML_UNARY_OP_SILU:
@@ -5325,6 +5513,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                return ctx->device->pipeline_tanh_f32;
            }
            break;
+        case GGML_UNARY_OP_SIGMOID:
+            if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                return ctx->device->pipeline_sigmoid_f32;
+            }
+            break;
        default:
            break;
        }
@@ -5344,7 +5537,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
        }
        return nullptr;
+    case GGML_OP_SOFT_MAX_BACK:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_back_f32;
+        }
+        return nullptr;
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5426,6 +5625,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv6_f32;
        }
        return nullptr;
+    case GGML_OP_RWKV_WKV7:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv7_f32;
+        }
+        return nullptr;
     case GGML_OP_OPT_STEP_ADAMW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5672,7 +5876,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
         {
@@ -5696,6 +5903,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         } break;
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
     case GGML_OP_GET_ROWS:
@@ -5791,7 +5999,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-    } else if (op == GGML_OP_ROPE) {
+    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
         if (use_src2) {
@@ -5919,23 +6127,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
-    const ggml_tensor * k = dst->src[0];
-    const ggml_tensor * v = dst->src[1];
-    const ggml_tensor * r = dst->src[2];
-    const ggml_tensor * tf = dst->src[3];
-    const ggml_tensor * td = dst->src[4];
-    const ggml_tensor * state = dst->src[5];
-
-    GGML_ASSERT(!ggml_is_quantized(k->type));
-    GGML_ASSERT(!ggml_is_quantized(v->type));
-    GGML_ASSERT(!ggml_is_quantized(r->type));
-    GGML_ASSERT(!ggml_is_quantized(tf->type));
-    GGML_ASSERT(!ggml_is_quantized(td->type));
-    GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+    GGML_ASSERT(version == 6 || version == 7);
+    int num_srcs = version == 6 ? 6 : 7;
+
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
     GGML_ASSERT(dst->buffer != nullptr);

-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
     GGML_ASSERT(pipeline != nullptr);

     if (dryrun) {
@@ -5944,89 +6146,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
     }

     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
-    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
-    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
-    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+    }

     ggml_vk_sync_buffers(subctx);

-    vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
-    size_t dst_offset = 0, k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0;
-    bool DST_uma = false, K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false;
+    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };

     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
-        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
-        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
-        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
-        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        for (int i = 0; i < num_srcs; i++) {
+            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+            srcs_uma[i] = d_srcs[i] != nullptr;
+        }

-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        R_uma = d_R != nullptr;
-        TF_uma = d_TF != nullptr;
-        TD_uma = d_TD != nullptr;
-        STATE_uma = d_State != nullptr;
-        DST_uma = d_D != nullptr;
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        dst_uma = d_D != nullptr;
     }

-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!R_uma) {
-        d_R = r_buf_ctx->dev_buffer;
-        r_offset = vk_tensor_offset(r) + r->view_offs;
-    }
-    if (!TF_uma) {
-        d_TF = tf_buf_ctx->dev_buffer;
-        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
-    }
-    if (!TD_uma) {
-        d_TD = td_buf_ctx->dev_buffer;
-        td_offset = vk_tensor_offset(td) + td->view_offs;
-    }
-    if (!STATE_uma) {
-        d_State = state_buf_ctx->dev_buffer;
-        state_offset = vk_tensor_offset(state) + state->view_offs;
+    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    for (int i = 0; i < num_srcs; i++) {
+        src_sizes[i] = ggml_nbytes(dst->src[i]);
+        if (!srcs_uma[i]) {
+            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+        }
     }
-    if (!DST_uma) {
+
+    const uint64_t dst_size = ggml_nbytes(dst);
+    if (!dst_uma) {
         d_D = dst_buf_ctx->dev_buffer;
         dst_offset = vk_tensor_offset(dst) + dst->view_offs;
     }

-    const uint64_t k_size = ggml_nbytes(k);
-    const uint64_t v_size = ggml_nbytes(v);
-    const uint64_t r_size = ggml_nbytes(r);
-    const uint64_t tf_size = ggml_nbytes(tf);
-    const uint64_t td_size = ggml_nbytes(td);
-    const uint64_t state_size = ggml_nbytes(state);
-    const uint64_t dst_size = ggml_nbytes(dst);
-
     std::array<uint32_t, 3> elements = {
         (uint32_t)(pc.B * pc.H),
         1,
         1
     };

-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_K, k_offset, k_size },
-        vk_subbuffer{ d_V, v_offset, v_size },
-        vk_subbuffer{ d_R, r_offset, r_size },
-        vk_subbuffer{ d_TF, tf_offset, tf_size },
-        vk_subbuffer{ d_TD, td_offset, td_size },
-        vk_subbuffer{ d_State, state_offset, state_size },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    if (version == 6) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    } else if (version == 7) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+    } else {
+        // shouldn't happen
+        GGML_ASSERT(false);
+    }
 }

 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
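The rewrite above folds the WKV6-specific helper into one `ggml_vk_op_f32_wkv` keyed by a version number, so WKV7 only differs in binding one more source tensor. A stubbed-out sketch of that shape; the buffer type and "dispatch" are placeholders:

```cpp
#include <cassert>
#include <cstdio>

struct FakeBuf { int id; };  // placeholder for a device buffer binding

static void dispatch_wkv(int version, const FakeBuf * srcs) {
    assert(version == 6 || version == 7);
    const int num_srcs = version == 6 ? 6 : 7;  // WKV7 binds one extra tensor
    for (int i = 0; i < num_srcs; i++) {
        std::printf("bind src[%d] = buf %d\n", i, srcs[i].id);
    }
    std::printf("bind dst, push constants for v%d\n\n", version);
}

int main() {
    const FakeBuf bufs[7] = { {0}, {1}, {2}, {3}, {4}, {5}, {6} };
    dispatch_wkv(6, bufs);
    dispatch_wkv(7, bufs);
}
```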
@@ -6035,7 +6221,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];

-    ggml_vk_op_f32_rwkv6(
+    ggml_vk_op_f32_wkv(
         ctx, subctx, dst,
         {
             (uint32_t)n_seqs,
@@ -6043,6 +6229,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
             (uint32_t)n_embed,
             (uint32_t)n_heads,
         },
+        6,
+        dryrun
+    );
+}
+
+static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[2];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[1];
+    const size_t n_seqs = dst->src[6]->ne[1];
+
+    ggml_vk_op_f32_wkv(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        7,
         dryrun
     );
 }
@@ -6313,6 +6519,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;

@@ -6335,6 +6545,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }

+static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
+static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -6370,7 +6590,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }

-static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+}
+
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode = ((int32_t *) dst->op_params)[2];
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6398,7 +6623,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        sections[0], sections[1], sections[2], sections[3],
+        sections[0], sections[1], sections[2], sections[3], backprop
     }, dryrun);
 }

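Routing GGML_OP_ROPE_BACK through the same `ggml_vk_rope` entry point with a `backprop` flag works because the backward pass of a rotary embedding is, mathematically, a rotation by the negated angle. A toy illustration of that view (not the shader's actual math):

```cpp
#include <cmath>
#include <cstdio>

// Rotate a 2D pair by theta; the backward pass applies the inverse rotation.
static void rotate(float & x, float & y, float theta, bool backprop) {
    const float t = backprop ? -theta : theta;
    const float nx = x * std::cos(t) - y * std::sin(t);
    const float ny = x * std::sin(t) + y * std::cos(t);
    x = nx; y = ny;
}

int main() {
    float x = 1.0f, y = 0.0f;
    rotate(x, y, 0.5f, false);  // forward (GGML_OP_ROPE)
    rotate(x, y, 0.5f, true);   // backward (GGML_OP_ROPE_BACK) undoes it
    std::printf("x=%f y=%f\n", x, y);  // back to ~(1, 0)
}
```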
@@ -6719,7 +6944,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
             ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
@@ -7064,7 +7289,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
             ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
             m, n, k,
             k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1
+            split_k, batch, batch, batch, 1, 1, n
         );
     }
     ggml_vk_ctx_end(subctx);
@@ -7295,6 +7520,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_RELU:
        case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
            break;
        default:
            return false;
@@ -7319,12 +7545,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_CPY:
    case GGML_OP_CONT:
    case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
    case GGML_OP_ARGSORT:
@@ -7336,6 +7567,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
    case GGML_OP_OPT_STEP_ADAMW:
@@ -7377,13 +7609,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_CPY:
    case GGML_OP_CONT:
    case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
    case GGML_OP_UNARY:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
@@ -7475,6 +7712,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_DUP:
        ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_SILU_BACK:
+        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
        break;
    case GGML_OP_NORM:
        ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7728,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_RMS_NORM:
        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_RMS_NORM_BACK:
+        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
+        break;
+    case GGML_OP_L2_NORM:
+        ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
        break;
    case GGML_OP_UNARY:
        switch (ggml_get_unary_op(node)) {
@@ -7495,6 +7744,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_RELU:
        case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
            ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
            break;
        default:
@@ -7508,9 +7758,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_SOFT_MAX:
        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_SOFT_MAX_BACK:
+        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
        break;
    case GGML_OP_ROPE:
-        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+        break;
+    case GGML_OP_ROPE_BACK:
+        ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);

        break;
    case GGML_OP_ARGSORT:
@@ -7568,6 +7826,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

        break;

+    case GGML_OP_RWKV_WKV7:
+        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+        break;
+
    case GGML_OP_OPT_STEP_ADAMW:
        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

@@ -7636,12 +7899,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_CPY:
    case GGML_OP_CONT:
    case GGML_OP_DUP:
+    case GGML_OP_SILU_BACK:
    case GGML_OP_NORM:
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
+    case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
+    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
    case GGML_OP_RESHAPE:
    case GGML_OP_VIEW:
    case GGML_OP_PERMUTE:
@@ -7656,6 +7924,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
+    case GGML_OP_RWKV_WKV7:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
@@ -7670,6 +7939,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_RELU:
        case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
            buf = tensor->buffer;
            break;
        default:
@@ -7844,11 +8114,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
    UNUSED(buffer);
 }

-static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
    if (tensor->view_src != nullptr) {
        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
    }
+    return GGML_STATUS_SUCCESS;
 }

 static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -8165,8 +8436,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

+    uint64_t total_mat_mul_bytes = 0;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
    }
    if (ctx->device->need_compiles) {
        ggml_vk_load_shaders(ctx->device);
@@ -8187,17 +8462,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
    bool first_node_in_batch = true; // true if next node will be first node in a batch
    int submit_node_idx = 0; // index to first node in a batch

-    // Submit work
-    //
-
+    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+    // (and scaled down based on model size, so smaller models submit earlier).
+    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+    int nodes_per_submit = 100;
    int submitted_nodes = 0;
    int submit_count = 0;
+    uint64_t mul_mat_bytes = 0;
+    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (first_node_in_batch) {
            submit_node_idx = i;
        }

-
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
+
+        bool submit = (submitted_nodes >= nodes_per_submit) ||
+                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+                      (i == last_node);

        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);

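The hunk above changes batching from a fixed node count to a byte budget of matmul weight traffic: submit at ~100 nodes, or once enough matmul bytes have accumulated, where the budget is min(100 MB, total/40). A standalone sketch of that decision loop with invented workload numbers:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t total_mat_mul_bytes = 2'000'000'000;  // pretend model: 2 GB of weights
    const uint64_t budget = std::min<uint64_t>(100u * 1000 * 1000, total_mat_mul_bytes / 40u);

    int nodes = 0;
    uint64_t bytes = 0;
    for (int i = 0; i < 200; i++) {
        nodes++;
        bytes += 1'500'000;  // pretend every node is a 1.5 MB matmul
        const bool submit = nodes >= 100 || bytes >= budget;
        if (submit) {
            std::printf("submit at node %d (%llu bytes batched)\n",
                        i, (unsigned long long)bytes);
            nodes = 0;
            bytes = 0;
        }
    }
}
```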
@@ -8214,13 +8499,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
        if (submit) {
            first_node_in_batch = true;
            submitted_nodes = 0;
-
-
-
-                    break;
-                default:
-                    nodes_per_submit = 100;
-                    break;
+            mul_mat_bytes = 0;
+            if (submit_count < 3) {
+                mul_mat_bytes_per_submit *= 2;
            }
            submit_count++;
        }
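The replacement above also ramps the budget: after each of the first three submits, `mul_mat_bytes_per_submit` doubles, so the GPU receives its first batches early and later batches grow larger. A sketch of just that ramp:

```cpp
#include <cstdio>

int main() {
    unsigned long long budget = 25'000'000;  // illustrative starting budget
    for (int submit_count = 0; submit_count < 6; submit_count++) {
        std::printf("submit %d: budget %llu bytes\n", submit_count, budget);
        if (submit_count < 3) {
            budget *= 2;  // mirrors the diff's ramp-up rule
        }
    }
}
```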
@@ -8371,7 +8652,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            case GGML_UNARY_OP_SILU:
            case GGML_UNARY_OP_RELU:
            case GGML_UNARY_OP_TANH:
-                return ggml_is_contiguous(op->src[0]);
+            case GGML_UNARY_OP_SIGMOID:
+                return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
            default:
                return false;
            }
@@ -8560,6 +8842,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_REPEAT_BACK:
            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
@@ -8569,22 +8852,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_NORM:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
-        case GGML_OP_ACC:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_RMS_NORM_BACK:
        case GGML_OP_SQR:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
        case GGML_OP_PAD:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
        case GGML_OP_ARGSORT:
        case GGML_OP_SUM:
        case GGML_OP_SUM_ROWS:
@@ -8594,6 +8882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_POOL_2D:
        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_OPT_STEP_ADAMW:
            return true;
@@ -8740,7 +9029,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
    UNUSED(instance_extensions);
 }

-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
    switch (props.vendorID) {
    case VK_VENDOR_ID_INTEL:
        // Intel drivers don't support coopmat properly yet
@@ -8748,10 +9037,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
    case VK_VENDOR_ID_AMD:
        if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
            // Workaround for AMD proprietary driver reporting support on all GPUs
-            const std::string name = props.deviceName;
-            return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
-                   name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
-                   name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+            return arch == vk_device_architecture::AMD_RDNA3;
        }
        return true;
    default:
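The AMD workaround above stops matching GPU marketing-name prefixes and keys off the detected architecture instead. A sketch contrasting the two checks; only `vk_device_architecture::AMD_RDNA3` comes from the diff, the rest is illustrative:

```cpp
#include <cstdio>
#include <string>

enum class vk_arch { OTHER, AMD_RDNA3 };

// Old style: brittle prefix matching on the device name (one prefix shown).
static bool coopmat_ok_old(const std::string & name) {
    return name.rfind("AMD Radeon RX 7", 0) == 0;
}

// New style: a single architecture check.
static bool coopmat_ok_new(vk_arch arch) {
    return arch == vk_arch::AMD_RDNA3;
}

int main() {
    std::printf("old: %d  new: %d\n",
                (int)coopmat_ok_old("AMD Radeon RX 7900 XTX"),
                (int)coopmat_ok_new(vk_arch::AMD_RDNA3));
}
```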
@@ -8976,15 +9262,25 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
    } else if (tensor->op == GGML_OP_RMS_NORM) {
        tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_L2_NORM) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
    } else if (tensor->op == GGML_OP_SOFT_MAX) {
        if (src1 != nullptr) {
            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
        } else {
            tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
        }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
    } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
        const int n_dims = ((int32_t *) tensor->op_params)[1];
        const int mode = ((int32_t *) tensor->op_params)[2];
        //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9293,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        const float beta_slow = ((float *) tensor->op_params)[10];
        if (mode & GGML_ROPE_TYPE_MROPE) {
            int32_t *sections = ((int32_t *) tensor->op_params) + 11;
-            tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
        } else {
-            tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
        }
    } else if (tensor->op == GGML_OP_UNARY) {
        switch (ggml_get_unary_op(tensor)) {
@@ -9018,6 +9322,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        case GGML_UNARY_OP_TANH:
            tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
            break;
+        case GGML_UNARY_OP_SIGMOID:
+            tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+            break;
        default:
            std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
            GGML_ABORT("fatal error");
@@ -9082,6 +9389,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
                                      src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+        tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+                                      src_clone[4], src_clone[5], src_clone[6]);
    } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
        src_clone[0]->flags = src0->flags;
        tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],