@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -167,6 +167,7 @@ struct vk_device_struct {
  uint32_t subgroup_size;
  uint32_t shader_core_count;
  bool uma;
+ bool prefer_host_memory;
  bool float_controls_rte_fp16;

  bool subgroup_size_control;
@@ -184,12 +185,12 @@ struct vk_device_struct {

  size_t idx;

- bool mul_mat_l;
- bool mul_mat_m;
- bool mul_mat_s;
- bool mul_mat_id_l;
- bool mul_mat_id_m;
- bool mul_mat_id_s;
+ bool mul_mat_l[GGML_TYPE_COUNT];
+ bool mul_mat_m[GGML_TYPE_COUNT];
+ bool mul_mat_s[GGML_TYPE_COUNT];
+ bool mul_mat_id_l[GGML_TYPE_COUNT];
+ bool mul_mat_id_m[GGML_TYPE_COUNT];
+ bool mul_mat_id_s[GGML_TYPE_COUNT];

  // set to true to indicate that some shaders need to be compiled after the dryrun
  bool need_compiles {};
@@ -221,6 +222,7 @@ struct vk_device_struct {
  vk_pipeline pipeline_acc_f32;
  vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
  vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+ vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
  vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
  vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -231,7 +233,7 @@ struct vk_device_struct {
  vk_pipeline pipeline_cos_f32;
  vk_pipeline pipeline_clamp_f32;
  vk_pipeline pipeline_pad_f32;
- vk_pipeline pipeline_repeat_f32;
+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -239,23 +241,32 @@ struct vk_device_struct {
  vk_pipeline pipeline_norm_f32;
  vk_pipeline pipeline_group_norm_f32;
  vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_rms_norm_back_f32;
  vk_pipeline pipeline_gelu_f32;
  vk_pipeline pipeline_gelu_quick_f32;
  vk_pipeline pipeline_silu_f32;
+ vk_pipeline pipeline_silu_back_f32;
  vk_pipeline pipeline_relu_f32;
  vk_pipeline pipeline_leaky_relu_f32;
  vk_pipeline pipeline_tanh_f32;
+ vk_pipeline pipeline_sigmoid_f32;
  vk_pipeline pipeline_diag_mask_inf_f32;
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
  vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+ vk_pipeline pipeline_soft_max_back_f32;
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+ vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
  vk_pipeline pipeline_argsort_f32;
  vk_pipeline pipeline_sum_rows_f32;
+ vk_pipeline pipeline_argmax_f32;
+ vk_pipeline pipeline_count_equal_i32;
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
  vk_pipeline pipeline_timestep_embedding_f32;
  vk_pipeline pipeline_pool2d_f32;
  vk_pipeline pipeline_rwkv_wkv6_f32;
+ vk_pipeline pipeline_opt_step_adamw_f32;

  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
  vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -493,6 +504,11 @@ struct vk_op_rope_push_constants {
  float corr_dims[2];
  float theta_scale;
  uint32_t has_ff;
+ uint32_t ne02;
+ uint32_t s1;
+ uint32_t s2;
+ int32_t sections[4];
+ uint32_t is_back;
  };

  struct vk_op_soft_max_push_constants {
@@ -1294,7 +1310,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
  static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
  vk_buffer buf;
  try {
- if (device->uma) {
+ if (device->prefer_host_memory) {
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ } else if (device->uma) {
  // Fall back to host memory type
  buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
  } else {
@@ -1378,7 +1396,37 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
  return {64, 64};
  };

- static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
+
+ uint32_t lut_size = 0;
+ switch (src0_type) {
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ lut_size = 2*2048;
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ lut_size = 8*256;
+ break;
+ case GGML_TYPE_IQ2_XS:
+ lut_size = 8*512;
+ break;
+ case GGML_TYPE_IQ2_S:
+ lut_size = 8*1024;
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ lut_size = 4*256;
+ break;
+ case GGML_TYPE_IQ3_S:
+ lut_size = 4*512;
+ break;
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ lut_size = 4*16;
+ break;
+ default:
+ break;
+ }
+
  // Needs to be kept up to date on shader changes
  const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
  const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
@@ -1388,13 +1436,20 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
  const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;

- [removed line 1391 is not shown in this diff view]
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+ VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
+ "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
+
+ return supported;
  }

  static void ggml_vk_load_shaders(vk_device& device) {
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

  // some shaders have a minimum subgroup size
+ const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
  const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
  const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);

@@ -1457,13 +1512,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
  const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
  const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;

- l_warptile = { 128, 128, 128, 16,
- m_warptile = { 128, 64, 64, 16,
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s,
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

- l_warptile_mmq = { 128, 128, 128, 32,
- m_warptile_mmq = { 128, 64, 64, 32,
- s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s,
+ l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+ m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
@@ -1472,62 +1527,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
  m_align = 64;
  s_align = 32;

- [removed lines 1475-1499 are not shown in this diff view]
+ for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
+ ggml_type t = (ggml_type)i;
+ // Disable medium and large matrix multiplication if not enough shared memory is available
+ // Check mmq warptiles as the largest configuration
+ // Throw an error if not enough for any matrix multiplication is available
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
+ device->mul_mat_m[i] = false;
+ device->mul_mat_l[i] = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
+ device->mul_mat_l[i] = false;
+ }
+
+ // Disable mul_mat_id if not enough shared memory is available
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
+ device->mul_mat_id_s[i] = false;
+ device->mul_mat_id_m[i] = false;
+ device->mul_mat_id_l[i] = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
+ device->mul_mat_id_m[i] = false;
+ device->mul_mat_id_l[i] = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
+ device->mul_mat_id_l[i] = false;
  }
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
- }
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
- // assert mul_mat_mat_id shaders will fit.
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
- }
- // Disable medium and large matrix multiplication if not enough shared memory is available
- // Check mmq warptiles as the largest configuration
- // Throw an error if not enough for any matrix multiplication is available
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
- std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
- throw std::runtime_error("Shared memory size too small for matrix multiplication.");
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
- device->mul_mat_m = false;
- device->mul_mat_l = false;
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
- device->mul_mat_l = false;
- }
-
- // Disable mul_mat_id if not enough shared memory is available
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
- device->mul_mat_id_s = false;
- device->mul_mat_id_m = false;
- device->mul_mat_id_l = false;
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
- device->mul_mat_id_m = false;
- device->mul_mat_id_l = false;
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
- device->mul_mat_id_l = false;
  }
  }

@@ -1617,6 +1642,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
  //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
  //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
+ //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
+ //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
  //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
  //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
  //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
@@ -1651,6 +1678,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1670,6 +1699,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1684,119 +1715,124 @@ static void ggml_vk_load_shaders(vk_device& device) {
  #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  if (device->coopmat_support) {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- if (device->mul_mat ## ID ## _l) \
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
- if (device->mul_mat ## ID ## _m) \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
- if (device->mul_mat ## ID ## _s) \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
- if (device->mul_mat ## ID ## _l) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
- if (device->mul_mat ## ID ## _m) \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
- if (device->mul_mat ## ID ## _s) \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \

  // Create 2 variants, {f16,f32} accumulator
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  if (device->coopmat_acc_f16_support) { \
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  } \
  if (device->coopmat_acc_f32_support) { \
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  } \

- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );

  if (device->coopmat_acc_f16_support) {
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  } else {
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- [removed lines 1753-1799 are not shown in this diff view]
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ }
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+
+ if (device->coopmat_acc_f16_support) {
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ } else {
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  }
  #undef CREATE_MM2
  #undef CREATE_MM
@@ -1804,141 +1840,143 @@ static void ggml_vk_load_shaders(vk_device& device) {
  #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  if (device->fp16) {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- if (device->mul_mat ## ID ## _l) \
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- if (device->mul_mat ## ID ## _m) \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- if (device->mul_mat ## ID ## _s) \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- if (device->mul_mat ## ID ## _l) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
- if (device->mul_mat ## ID ## _m) \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
- if (device->mul_mat ## ID ## _s) \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

  // Create 2 variants, {f16,f32} accumulator
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- CREATE_MM(pipeline_dequant_mul_mat_mat[
- [removed lines 1849-1874 are not shown in this diff view]
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1884
|
+
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1885
|
+
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1886
|
+
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1887
|
+
|
|
1888
|
+
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1889
|
+
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1890
|
+
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1891
|
+
|
|
1892
|
+
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1893
|
+
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1894
|
+
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1895
|
+
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1896
|
+
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1897
|
+
|
|
1898
|
+
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1899
|
+
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1900
|
+
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1901
|
+
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1902
|
+
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1903
|
+
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1904
|
+
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1905
|
+
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1906
|
+
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1907
|
+
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1908
|
+
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1909
|
+
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1910
|
+
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1911
|
+
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1875
1912
|
#undef CREATE_MM2
|
|
1876
1913
|
#undef CREATE_MM
|
|
1877
1914
|
} else {
|
|
1878
1915
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1879
|
-
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1880
|
-
if (device->mul_mat ## ID ## _l) \
|
|
1916
|
+
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1917
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
1881
1918
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
1882
|
-
if (device->mul_mat ## ID ## _m) \
|
|
1919
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
1883
1920
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
1884
|
-
if (device->mul_mat ## ID ## _s) \
|
|
1921
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
1885
1922
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
1886
|
-
if (device->mul_mat ## ID ## _l) \
|
|
1923
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
1887
1924
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
|
|
1888
|
-
if (device->mul_mat ## ID ## _m) \
|
|
1925
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
1889
1926
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
|
|
1890
|
-
if (device->mul_mat ## ID ## _s) \
|
|
1927
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
1891
1928
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
1892
1929
|
|
|
1893
|
-
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1894
|
-
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1895
|
-
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1896
|
-
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1897
|
-
|
|
1898
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1899
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1900
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1901
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1902
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1903
|
-
|
|
1904
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1905
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1906
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1907
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1908
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1909
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1910
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1911
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1912
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1913
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1914
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1915
|
-
CREATE_MM(pipeline_dequant_mul_mat_mat[
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
1933
|
-
|
|
1934
|
-
|
|
1935
|
-
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1930
|
+
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1931
|
+
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1932
|
+
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1933
|
+
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1934
|
+
|
|
1935
|
+
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1936
|
+
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1937
|
+
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1938
|
+
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1939
|
+
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1940
|
+
|
|
1941
|
+
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1942
|
+
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1943
|
+
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1944
|
+
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1945
|
+
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1946
|
+
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1947
|
+
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1948
|
+
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1949
|
+
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1950
|
+
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1951
|
+
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1952
|
+
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1953
|
+
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1954
|
+
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
1955
|
+
|
|
1956
|
+
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1957
|
+
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1958
|
+
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1959
|
+
|
|
1960
|
+
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1961
|
+
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1962
|
+
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1963
|
+
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1964
|
+
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1965
|
+
|
|
1966
|
+
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1967
|
+
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1968
|
+
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1969
|
+
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1970
|
+
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1971
|
+
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1972
|
+
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1973
|
+
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1974
|
+
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1975
|
+
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1976
|
+
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1977
|
+
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1978
|
+
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1979
|
+
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
1942
1980
|
#undef CREATE_MM
|
|
1943
1981
|
}
|
|
1944
1982
|
|
|
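The reworked `CREATE_MM`/`CREATE_MM2` macros thread the ggml tensor type through to per-size enable flags, so each quantization format can opt out of matmul tile sizes that are slow or unsupported on the current GPU. The sketch below illustrates roughly what one per-type expansion amounts to; the struct and names are simplified stand-ins, not the actual `vk_device` layout or preprocessor output.

```cpp
#include <cstdio>

// Minimal stand-ins for the ggml/Vulkan types involved (illustrative only).
enum ggml_type_sketch { TYPE_F16, TYPE_Q4_K, TYPE_IQ1_S, TYPE_COUNT };

struct vk_device_sketch {
    // One enable flag per (tile size, tensor type), as introduced by this change.
    bool mul_mat_l[TYPE_COUNT], mul_mat_m[TYPE_COUNT], mul_mat_s[TYPE_COUNT];
};

// Roughly what a CREATE_MM(TYPE, ...) expansion does for one type:
// compile only the tile-size variants whose flag is set for that type.
static void create_mm_for(vk_device_sketch &dev, ggml_type_sketch t) {
    if (dev.mul_mat_l[t]) std::printf("compile large matmul pipeline for type %d\n", t);
    if (dev.mul_mat_m[t]) std::printf("compile medium matmul pipeline for type %d\n", t);
    if (dev.mul_mat_s[t]) std::printf("compile small matmul pipeline for type %d\n", t);
}

int main() {
    vk_device_sketch dev{};
    dev.mul_mat_m[TYPE_Q4_K] = dev.mul_mat_s[TYPE_Q4_K] = true; // e.g. defaults that skip the large tile
    create_mm_for(dev, TYPE_Q4_K);
}
```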
@@ -1954,6 +1992,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 }
 } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
 rm_stdq = 2;
+uint32_t rm_iq = 2 * rm_kq;

 for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -1968,13 +2007,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);

 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -1988,13 +2029,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 }

 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2009,13 +2052,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
-ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);

 // dequant shaders
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
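For the new IQ-family matrix-vector shaders, the diff derives `rm_iq` as twice the k-quant rows-per-workgroup value and passes it both in the workgroup denominators and as a specialization constant. The snippet below is a rough sketch of that sizing decision; the struct and function names are illustrative, and the real values come from the vendor-specific tuning earlier in this function.

```cpp
#include <cstdint>

// Sketch of how rows-per-workgroup feed the mul_mat_vec pipeline parameters.
// rm_kq / subgroup_size_16 mirror the tuning variables used in the diff.
struct mmv_dispatch_sketch {
    uint32_t wg_denoms[3];   // e.g. {rm_iq, 1, 1}
    uint32_t spec_consts[3]; // e.g. {subgroup_size_16, rm_iq, ncols}
};

static mmv_dispatch_sketch plan_iq_mmv(uint32_t rm_kq, uint32_t subgroup_size_16, uint32_t ncols) {
    const uint32_t rm_iq = 2 * rm_kq; // the diff sets rm_iq = 2 * rm_kq
    return { {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, ncols} };
}
```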
@@ -2029,6 +2074,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
@@ -2045,6 +2092,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2060,6 +2109,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2076,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2106,6 +2158,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

 ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

+ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -2128,13 +2182,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

@@ -2142,22 +2199,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

 if (device->float_controls_rte_fp16) {
 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 } else {
 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 }

 ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);

+ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
 ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

+ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
 if (device->float_controls_rte_fp16) {
 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -2171,6 +2239,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

+ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
 for (auto &c : compiles) {
 c.wait();
 }
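Several of the pipelines added above (`sub_f32`, `rms_norm_back_f32`, `repeat_back_f32`, `silu_back_f32`, `soft_max_back_f32`, `argmax_f32`, `count_equal_i32`, `opt_step_adamw_f32`) correspond to ggml operations used by the backward pass and the built-in AdamW optimizer step. The snippet below is only an illustration of that op-to-pipeline correspondence; it is not the backend's actual dispatch code, and the string keys are stand-ins for the ggml op enums.

```cpp
#include <string>

// Illustrative pairing of op names with the shader pipelines added in this diff.
// The real backend routes ops through its own dispatch tables in ggml-vulkan.cpp.
static std::string pipeline_for_op_sketch(const std::string & op) {
    if (op == "SUB")            return "sub_f32";
    if (op == "SILU_BACK")      return "silu_back_f32";
    if (op == "RMS_NORM_BACK")  return "rms_norm_back_f32";
    if (op == "REPEAT_BACK")    return "repeat_back_f32";
    if (op == "SOFT_MAX_BACK")  return "soft_max_back_f32";
    if (op == "ARGMAX")         return "argmax_f32";
    if (op == "COUNT_EQUAL")    return "count_equal_i32";
    if (op == "OPT_STEP_ADAMW") return "opt_step_adamw_f32";
    return ""; // unsupported in this sketch
}
```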
@@ -2206,6 +2276,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->physical_device = physical_devices[dev_num];
 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();

+const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
+device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
+
 bool fp16_storage = false;
 bool fp16_compute = false;
 bool maintenance4_support = false;
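The new `GGML_VK_PREFER_HOST_MEMORY` toggle is read straight from the environment: any value (even an empty string) flips `device->prefer_host_memory`, which the backend can then consult when picking memory types. Setting the variable before launching the process, for example by exporting `GGML_VK_PREFER_HOST_MEMORY=1`, is enough to enable it. A minimal sketch of the same pattern follows; the struct and function are illustrative, not the backend's actual allocation logic.

```cpp
#include <cstdlib>

// Sketch: environment-driven toggle, mirroring the getenv() pattern in the diff.
struct device_flags_sketch {
    bool prefer_host_memory = false;
};

static device_flags_sketch read_device_flags() {
    device_flags_sketch flags;
    // Presence of the variable is what matters, not its value.
    flags.prefer_host_memory = std::getenv("GGML_VK_PREFER_HOST_MEMORY") != nullptr;
    return flags;
}
```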
@@ -2623,34 +2696,36 @@ static vk_device ggml_vk_get_device(size_t idx) {

 // Shaders
 // Disable matmul tile sizes early if performance low or not supported
-
+for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
+switch (device->vendor_id) {
 #ifndef GGML_VULKAN_RUN_TESTS
- … (17 removed lines are truncated in the rendered diff)
+case VK_VENDOR_ID_AMD:
+case VK_VENDOR_ID_INTEL:
+device->mul_mat_l[i] = false;
+device->mul_mat_m[i] = true;
+device->mul_mat_s[i] = true;
+device->mul_mat_id_l[i] = false;
+device->mul_mat_id_m[i] = true;
+device->mul_mat_id_s[i] = true;
+break;
+case VK_VENDOR_ID_APPLE:
+device->mul_mat_l[i] = false;
+device->mul_mat_m[i] = true;
+device->mul_mat_s[i] = false;
+device->mul_mat_id_l[i] = false;
+device->mul_mat_id_m[i] = true;
+device->mul_mat_id_s[i] = false;
+break;
 #endif
- … (8 removed lines are truncated in the rendered diff)
+default:
+device->mul_mat_l[i] = true;
+device->mul_mat_m[i] = true;
+device->mul_mat_s[i] = true;
+device->mul_mat_id_l[i] = true;
+device->mul_mat_id_m[i] = true;
+device->mul_mat_id_s[i] = true;
+break;
+}
 }

 ggml_vk_load_shaders(device);
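With the gating flags now indexed by tensor type, the defaults are filled in per type inside a loop over `GGML_TYPE_COUNT`, keyed on the device vendor: AMD and Intel keep the medium and small tiles but drop the large ones, Apple additionally drops the small tiles, and every other vendor keeps all sizes. A condensed sketch of that initialization is shown below; the vendor IDs and type count are stand-in constants, and the real code also honors the `GGML_VULKAN_RUN_TESTS` build option.

```cpp
#include <cstdint>

// Stand-ins for the vendor IDs and type count used in the diff (illustrative only).
enum : uint32_t { VENDOR_AMD = 0x1002, VENDOR_APPLE = 0x106B, VENDOR_INTEL = 0x8086 };
constexpr int TYPE_COUNT = 40; // placeholder for GGML_TYPE_COUNT

struct tile_flags_sketch {
    bool l[TYPE_COUNT], m[TYPE_COUNT], s[TYPE_COUNT];
};

// Per-type tile-size defaults chosen by vendor, mirroring the switch added in this hunk.
static void init_tile_defaults(tile_flags_sketch &f, uint32_t vendor_id) {
    for (int i = 0; i < TYPE_COUNT; ++i) {
        switch (vendor_id) {
            case VENDOR_AMD:
            case VENDOR_INTEL: f.l[i] = false; f.m[i] = true; f.s[i] = true;  break;
            case VENDOR_APPLE: f.l[i] = false; f.m[i] = true; f.s[i] = false; break;
            default:           f.l[i] = true;  f.m[i] = true; f.s[i] = true;  break;
        }
    }
}
```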
@@ -2780,8 +2855,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";

  std::string device_name = props2.properties.deviceName.data();
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
+ props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());

  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2791,14 +2867,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
  static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);

- void ggml_vk_instance_init() {
+ static void ggml_vk_instance_init() {
  if (vk_instance_initialized) {
  return;
  }
  VK_LOG_DEBUG("ggml_vk_instance_init()");

- vk_instance_initialized = true;
-
  uint32_t api_version = vk::enumerateInstanceVersion();

  if (api_version < VK_API_VERSION_1_2) {
@@ -2849,6 +2923,7 @@ void ggml_vk_instance_init() {
  GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
  }
  vk_instance.instance = vk::createInstance(instance_create_info);
+ vk_instance_initialized = true;

  size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();

@@ -2873,7 +2948,7 @@ void ggml_vk_instance_init() {
  // Make sure at least one device exists
  if (devices.empty()) {
  std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
-
+ return;
  }

  // Default to using all dedicated GPUs
@@ -3007,6 +3082,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ2_S:
@@ -3061,6 +3138,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ2_S:
@@ -3098,6 +3177,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ2_S:
@@ -3147,6 +3228,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ2_S:
@@ -3179,6 +3262,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
  case GGML_TYPE_Q4_K:
  case GGML_TYPE_Q5_K:
  case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
  case GGML_TYPE_IQ2_XXS:
  case GGML_TYPE_IQ2_XS:
  case GGML_TYPE_IQ2_S:
@@ -3721,6 +3806,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
  }
  }

+ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
+
+ ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
+ }
+
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");

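The new async memset records a Vulkan fill-buffer command, which stamps a repeated 32-bit word over the range (vkCmdFillBuffer expects 4-byte-aligned offset and size). A small CPU-side sketch of that semantics, as an illustration rather than the backend's implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// CPU-side illustration of a 32-bit "fill buffer": the same word is written
// across the range, so offset and size are kept multiples of 4 bytes.
void fill_u32(std::vector<uint8_t>& buf, size_t offset, size_t size, uint32_t word) {
    assert(offset % 4 == 0 && size % 4 == 0 && offset + size <= buf.size());
    for (size_t i = 0; i < size; i += 4) {
        std::memcpy(buf.data() + offset + i, &word, 4);
    }
}
```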
@@ -3755,31 +3846,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
  return split_k;
  }

- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

  if (ctx->device->coopmat2) {
- if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
+ if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
  return aligned ? mmp->a_l : mmp->l;
  }
- if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
+ if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
  return aligned ? mmp->a_m : mmp->m;
  }
  return aligned ? mmp->a_s : mmp->s;
  }

- if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
+ if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
  return aligned ? mmp->a_s : mmp->s;
  }
- if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
+ if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
  return aligned ? mmp->a_m : mmp->m;
  }
  return aligned ? mmp->a_l : mmp->l;
  }

- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
  }

  static void ggml_vk_matmul(
@@ -3806,31 +3897,31 @@ static void ggml_vk_matmul(
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
  }

- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

  if (ctx->device->coopmat2) {
- if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
+ if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
  return aligned ? mmp->a_l : mmp->l;
  }
- if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
  return aligned ? mmp->a_m : mmp->m;
  }
  return aligned ? mmp->a_s : mmp->s;
  }

- if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
+ if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
  return aligned ? mmp->a_s : mmp->s;
  }
- if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
  return aligned ? mmp->a_m : mmp->m;
  }
  return aligned ? mmp->a_l : mmp->l;
  }

- static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
- return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
  }

  static void ggml_vk_matmul_id(
@@ -4011,10 +4102,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  const int y_ne = ne11 * ne10;
  const int d_ne = ne11 * ne01;

- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;

- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);

@@ -4102,7 +4193,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz *
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
@@ -4593,10 +4684,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  const uint64_t y_ne = ne11 * ne10;
  const uint64_t d_ne = ne21 * ne20;

- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -4679,7 +4770,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz *
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
@@ -5127,6 +5218,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
  }
  return nullptr;
+ case GGML_OP_SUB:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+ }
+ return nullptr;
  case GGML_OP_MUL:
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5188,10 +5284,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_repeat_f32;
  }
  return nullptr;
+ case GGML_OP_REPEAT_BACK:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_repeat_back_f32;
+ }
+ return nullptr;
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+ case GGML_OP_SILU_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_silu_back_f32;
+ }
+ return nullptr;
  case GGML_OP_NORM:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_norm_f32;
@@ -5207,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_rms_norm_f32;
  }
  return nullptr;
+ case GGML_OP_RMS_NORM_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rms_norm_back_f32;
+ }
+ return nullptr;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(dst)) {
  case GGML_UNARY_OP_SILU:
@@ -5234,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_tanh_f32;
  }
  break;
+ case GGML_UNARY_OP_SIGMOID:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sigmoid_f32;
+ }
+ break;
  default:
  break;
  }
@@ -5253,10 +5369,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
  }
  return nullptr;
+ case GGML_OP_SOFT_MAX_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_soft_max_back_f32;
+ }
+ return nullptr;
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  {
  const int mode = ((const int32_t *) dst->op_params)[2];
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

  if (is_neox) {
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5265,6 +5389,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
  return ctx->device->pipeline_rope_neox_f16;
  }
+ } else if (is_mrope && !is_vision) {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rope_multi_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_multi_f16;
+ }
+ } else if (is_vision) {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rope_vision_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_vision_f16;
+ }
  } else {
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_rope_norm_f32;
@@ -5280,11 +5418,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_argsort_f32;
  }
  return nullptr;
+ case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_sum_rows_f32;
  }
  return nullptr;
+ case GGML_OP_ARGMAX:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
+ return ctx->device->pipeline_argmax_f32;
+ }
+ return nullptr;
+ case GGML_OP_COUNT_EQUAL:
+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
+ return ctx->device->pipeline_count_equal_i32;
+ }
+ return nullptr;
  case GGML_OP_IM2COL:
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_im2col_f32;
@@ -5308,6 +5457,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_rwkv_wkv6_f32;
  }
  return nullptr;
+ case GGML_OP_OPT_STEP_ADAMW:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_opt_step_adamw_f32;
+ }
+ return nullptr;
  case GGML_OP_LEAKY_RELU:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_leaky_relu_f32;
@@ -5325,6 +5479,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
  case GGML_OP_CPY:
  case GGML_OP_GET_ROWS:
  case GGML_OP_ADD:
+ case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_CONCAT:
@@ -5335,6 +5490,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
  case GGML_OP_CLAMP:
  case GGML_OP_PAD:
  case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
+ case GGML_OP_ROPE:
  return true;
  default:
  return false;
@@ -5546,8 +5703,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  switch (op) {
  case GGML_OP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGMAX:
  {
  const uint32_t nr = ggml_nrows(src0);
  if (nr > 262144) {
@@ -5558,6 +5718,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  elements = { nr, 1, 1 };
  }
  } break;
+ case GGML_OP_SUM:
+ // We use GGML_OP_SUM_ROWS with 1 row.
+ elements = { 1, 1, 1 };
+ break;
  case GGML_OP_GROUP_NORM:
  {
  const uint32_t num_groups = dst->op_params[0];
@@ -5565,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  } break;
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
  break;
  case GGML_OP_GET_ROWS:
@@ -5604,6 +5769,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  elements = { N * OC * OH * OW, 1, 1};
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_SUB:
  case GGML_OP_DIV:
  case GGML_OP_MUL:
  case GGML_OP_SCALE:
@@ -5613,6 +5779,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  case GGML_OP_CLAMP:
  case GGML_OP_PAD:
  case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
  case GGML_OP_CPY:
  case GGML_OP_CONCAT:
  case GGML_OP_UPSCALE:
@@ -5658,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

  ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (op == GGML_OP_ROPE) {
+ } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
  // Empty src2 is possible in rope, but the shader needs a buffer
  vk_subbuffer subbuf_z;
  if (use_src2) {
@@ -5673,6 +5840,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  // im2col uses only src1 and dst buffers
  ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else if (op == GGML_OP_COUNT_EQUAL) {
+ ggml_vk_sync_buffers(subctx);
+ // count_equal assumes that destination buffer is initialized with zeroes
+ ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
  } else if (use_src2) {
  ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
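The `GGML_OP_COUNT_EQUAL` path zeroes the destination before dispatch because the result is only ever accumulated into (presumably via atomic adds in the shader, which is an assumption here). A host-side analogue of the zero-then-accumulate pattern, purely for illustration:

```cpp
#include <cstdint>
#include <vector>

// The running count plays the role of the memset-to-zero destination buffer:
// it must start at zero because the kernel only ever adds to it.
int64_t count_equal(const std::vector<int32_t>& a, const std::vector<int32_t>& b) {
    int64_t count = 0;
    const size_t n = a.size() < b.size() ? a.size() : b.size();
    for (size_t i = 0; i < n; ++i) {
        count += (a[i] == b[i]) ? 1 : 0;
    }
    return count;
}
```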
@@ -5735,6 +5908,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f, 0,
+ }, dryrun);
+ }
+
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
  const uint32_t src0_type_size = ggml_type_size(src0->type);
  const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -5893,6 +6081,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
  );
  }

+ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
+ const ggml_tensor * x = dst->src[0];
+ const ggml_tensor * g = dst->src[1];
+ const ggml_tensor * gm = dst->src[2];
+ const ggml_tensor * gv = dst->src[3];
+ const ggml_tensor * p = dst->src[4];
+
+ GGML_ASSERT(x->type == GGML_TYPE_F32);
+ GGML_ASSERT(g->type == GGML_TYPE_F32);
+ GGML_ASSERT(gm->type == GGML_TYPE_F32);
+ GGML_ASSERT(gv->type == GGML_TYPE_F32);
+ GGML_ASSERT(p->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->buffer != nullptr);
+ GGML_ASSERT(ggml_is_contiguous(x));
+ GGML_ASSERT(ggml_is_contiguous(g));
+ GGML_ASSERT(ggml_is_contiguous(gm));
+ GGML_ASSERT(ggml_is_contiguous(gv));
+ GGML_ASSERT(ggml_is_contiguous(p));
+ GGML_ASSERT(ggml_are_same_shape(x, g));
+ GGML_ASSERT(ggml_are_same_shape(x, gm));
+ GGML_ASSERT(ggml_are_same_shape(x, gv));
+ GGML_ASSERT(ggml_nelements(p) == 7);
+
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
+ GGML_ASSERT(pipeline != nullptr);
+
+ if (dryrun) {
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+ return;
+ }
+
+ ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
+ ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
+ ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
+ ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
+ ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
+
+ ggml_vk_sync_buffers(subctx);
+
+ vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
+ size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
+ bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
+ ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
+ ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
+ ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
+ ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
+
+ X_uma = d_X != nullptr;
+ G_uma = d_G != nullptr;
+ GM_uma = d_GM != nullptr;
+ GV_uma = d_GV != nullptr;
+ P_uma = d_P != nullptr;
+ }
+
+ if (!X_uma) {
+ d_X = x_buf_ctx->dev_buffer;
+ x_offset = vk_tensor_offset(x) + x->view_offs;
+ }
+ if (!G_uma) {
+ d_G = g_buf_ctx->dev_buffer;
+ g_offset = vk_tensor_offset(g) + g->view_offs;
+ }
+ if (!GM_uma) {
+ d_GM = gm_buf_ctx->dev_buffer;
+ gm_offset = vk_tensor_offset(gm) + gm->view_offs;
+ }
+ if (!GV_uma) {
+ d_GV = gv_buf_ctx->dev_buffer;
+ gv_offset = vk_tensor_offset(gv) + gv->view_offs;
+ }
+ if (!P_uma) {
+ d_P = p_buf_ctx->dev_buffer;
+ p_offset = vk_tensor_offset(p) + p->view_offs;
+ }
+
+ const uint64_t x_size = ggml_nbytes(x);
+ const uint64_t g_size = ggml_nbytes(g);
+ const uint64_t gm_size = ggml_nbytes(gm);
+ const uint64_t gv_size = ggml_nbytes(gv);
+ const uint64_t p_size = ggml_nbytes(p);
+
+ std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_X, x_offset, x_size },
+ vk_subbuffer{ d_G, g_offset, g_size },
+ vk_subbuffer{ d_GM, gm_offset, gm_size },
+ vk_subbuffer{ d_GV, gv_offset, gv_size },
+ vk_subbuffer{ d_P, p_offset, p_size },
+ }, sizeof(vk_op_push_constants), &pc, elements);
+ }
+
+ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+ const size_t n = ggml_nelements(dst->src[0]);
+
+ ggml_vk_op_f32_opt_step_adamw(
+ ctx, subctx, dst,
+ { (uint32_t)n, 0, 0.0f, 0.0f },
+ dryrun
+ );
+ }
+
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
  int * op_params = (int *)dst->op_params;

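For reference, the `GGML_OP_OPT_STEP_ADAMW` pipeline wired up above performs an AdamW optimizer update on the GPU. The sketch below is a generic CPU-side AdamW step, not the shader itself; the exact packing of the op's 7-element parameter tensor is defined elsewhere in ggml and is not reproduced here.

```cpp
#include <cmath>
#include <vector>

// Standard AdamW update: first/second moment estimates m and v, bias
// correction by step t, plus decoupled weight decay wd.
void adamw_step(std::vector<float>& x, const std::vector<float>& grad,
                std::vector<float>& m, std::vector<float>& v,
                float alpha, float beta1, float beta2, float eps, float wd, int t) {
    const float b1h = 1.0f - std::pow(beta1, (float) t);
    const float b2h = 1.0f - std::pow(beta2, (float) t);
    for (size_t i = 0; i < x.size(); ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
        const float mh = m[i] / b1h;
        const float vh = v[i] / b2h;
        x[i] -= alpha * (mh / (std::sqrt(vh) + eps) + wd * x[i]);
    }
}
```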
@@ -6026,6 +6319,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
  }, dryrun);
  }

+ static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }, dryrun);
+ }
+
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  const uint32_t src0_type_size = ggml_type_size(src0->type);
  const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -6040,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  float * op_params = (float *)dst->op_params;

@@ -6062,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
  }

+ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ }
+
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
  }
@@ -6097,9 +6413,14 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
  }, dryrun);
  }

- static void
+ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+ }
+
+ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
  const int n_dims = ((int32_t *) dst->op_params)[1];
-
+ const int mode = ((int32_t *) dst->op_params)[2];
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  const float freq_base = ((float *) dst->op_params)[5];
@@ -6108,16 +6429,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
  const float attn_factor = ((float *) dst->op_params)[8];
  const float beta_fast = ((float *) dst->op_params)[9];
  const float beta_slow = ((float *) dst->op_params)[10];
+ int sections[4] {};
+ if (mode & GGML_ROPE_TYPE_MROPE) {
+ memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+ }

  float corr_dims[2];
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

  const float theta_scale = powf(freq_base, -2.0f/n_dims);

+ uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
+ uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
+
  ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
- src2 != nullptr,
+ src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
+ sections[0], sections[1], sections[2], sections[3], backprop
  }, dryrun);
  }

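The RoPE dispatch now distinguishes the standard, NEOX, multi-section (mrope) and vision variants from the mode bits in `op_params`, and for mrope it also copies four per-section sizes. A hedged sketch of that decoding follows; the bit values are placeholders (the real `GGML_ROPE_TYPE_*` constants live in ggml.h) and `decode_rope` is a hypothetical helper, not part of the diff.

```cpp
#include <cstdint>
#include <cstring>

// Placeholder bit values; the real constants are defined in ggml.h.
constexpr int FAKE_ROPE_TYPE_NEOX   = 1 << 1;
constexpr int FAKE_ROPE_TYPE_MROPE  = 1 << 3;
constexpr int FAKE_ROPE_TYPE_VISION = FAKE_ROPE_TYPE_MROPE | (1 << 4);

struct RopeKind { bool neox, mrope, vision; };

RopeKind decode_rope(int mode, const int32_t * op_params, int sections[4]) {
    RopeKind k;
    k.neox   = (mode & FAKE_ROPE_TYPE_NEOX)  != 0;
    k.mrope  = (mode & FAKE_ROPE_TYPE_MROPE) != 0;   // multi-section rope
    k.vision = (mode == FAKE_ROPE_TYPE_VISION);      // vision is an exact mode, not a single bit test
    if (k.mrope) {
        // four per-section dimension counts follow the scalar parameters
        std::memcpy(sections, op_params + 11, sizeof(int) * 4);
    }
    return k;
}
```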
@@ -6140,10 +6469,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
  }, dryrun);
  }

+ static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
  }

+ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+ }
+
+ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
  const int32_t s0 = dst->op_params[0];
  const int32_t s1 = dst->op_params[1];
@@ -7002,15 +7343,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  break;
  default:
  return false;
  }
  break;
  case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
  case GGML_OP_GET_ROWS:
  case GGML_OP_ADD:
  case GGML_OP_ACC:
+ case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_CONCAT:
@@ -7024,22 +7368,30 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  case GGML_OP_ARGSORT:
+ case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_FLASH_ATTN_EXT:
+ case GGML_OP_OPT_STEP_ADAMW:
  break;
  default:
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7060,9 +7412,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  } else {
  switch (node->op) {
  case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
  case GGML_OP_ACC:
  case GGML_OP_GET_ROWS:
  case GGML_OP_ADD:
+ case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_CONCAT:
@@ -7076,15 +7430,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_UNARY:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_ARGSORT:
+ case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
@@ -7105,6 +7466,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_REPEAT:
  ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_REPEAT_BACK:
+ ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
+
  break;
  case GGML_OP_ACC:
  ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7117,6 +7482,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_ADD:
  ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);

+ break;
+ case GGML_OP_SUB:
+ ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_MUL:
  ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7163,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_DUP:
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_SILU_BACK:
+ ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_NORM:
  ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7175,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RMS_NORM:
  ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_RMS_NORM_BACK:
+ ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(node)) {
@@ -7183,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
  break;
  default:
@@ -7196,18 +7574,38 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_SOFT_MAX:
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+ break;
+ case GGML_OP_SOFT_MAX_BACK:
+ ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+ break;
+ case GGML_OP_ROPE_BACK:
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);

  break;
  case GGML_OP_ARGSORT:
  ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_SUM:
+ ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
+
  break;
  case GGML_OP_SUM_ROWS:
  ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_ARGMAX:
+ ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
+
+ break;
+ case GGML_OP_COUNT_EQUAL:
+ ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_IM2COL:
  ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7242,6 +7640,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RWKV_WKV6:
  ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);

+ break;
+
+ case GGML_OP_OPT_STEP_ADAMW:
+ ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
+
  break;
  default:
  return false;
@@ -7293,6 +7696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_ADD:
  case GGML_OP_ACC:
  case GGML_OP_GET_ROWS:
+ case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_CONCAT:
@@ -7306,25 +7710,34 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
  case GGML_OP_NONE:
  case GGML_OP_ARGSORT:
+ case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
+ case GGML_OP_OPT_STEP_ADAMW:
  buf = tensor->buffer;

  break;
@@ -7335,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  buf = tensor->buffer;
  break;
  default:
@@ -7509,11 +7923,21 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
  UNUSED(buffer);
  }

- static
+ static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
  VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
  if (tensor->view_src != nullptr) {
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
  }
+ return GGML_STATUS_SUCCESS;
+ }
+
+ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+ vk_buffer buf = buf_ctx->dev_buffer;
+
+ uint32_t val32 = (uint32_t)value * 0x01010101;
+ ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
  }

  static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
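`memset_tensor` receives a single byte value, but the underlying buffer fill works on 32-bit words, so the byte is replicated into all four lanes with `value * 0x01010101`. A tiny standalone check of that replication (illustrative only):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Replicate a byte into every lane of a 32-bit word, as done above before
    // handing the value to the word-sized buffer fill.
    const uint8_t  value = 0xAB;
    const uint32_t val32 = (uint32_t)value * 0x01010101u;
    assert(val32 == 0xABABABABu);
    return 0;
}
```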
@@ -7560,7 +7984,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
  /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
  /* .get_base = */ ggml_backend_vk_buffer_get_base,
  /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
- /* .memset_tensor = */
+ /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
  /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
  /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
  /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
@@ -8027,7 +8451,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8027
8451
|
case GGML_UNARY_OP_SILU:
|
|
8028
8452
|
case GGML_UNARY_OP_RELU:
|
|
8029
8453
|
case GGML_UNARY_OP_TANH:
|
|
8030
|
-
|
|
8454
|
+
case GGML_UNARY_OP_SIGMOID:
|
|
8455
|
+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
8031
8456
|
default:
|
|
8032
8457
|
return false;
|
|
8033
8458
|
}
|
|
@@ -8035,13 +8460,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8035
8460
|
case GGML_OP_MUL_MAT:
|
|
8036
8461
|
case GGML_OP_MUL_MAT_ID:
|
|
8037
8462
|
{
|
|
8463
|
+
ggml_type src0_type = op->src[0]->type;
|
|
8038
8464
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
8039
8465
|
const vk_device& device = ggml_vk_get_device(ctx->device);
|
|
8040
|
-
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
|
|
8466
|
+
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
|
|
8041
8467
|
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
|
|
8042
8468
|
return false;
|
|
8043
8469
|
}
|
|
8044
|
-
switch (
|
|
8470
|
+
switch (src0_type) {
|
|
8045
8471
|
case GGML_TYPE_F32:
|
|
8046
8472
|
case GGML_TYPE_F16:
|
|
8047
8473
|
case GGML_TYPE_Q4_0:
|
|
@@ -8054,6 +8480,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8054
8480
|
case GGML_TYPE_Q4_K:
|
|
8055
8481
|
case GGML_TYPE_Q5_K:
|
|
8056
8482
|
case GGML_TYPE_Q6_K:
|
|
8483
|
+
case GGML_TYPE_IQ1_S:
|
|
8484
|
+
case GGML_TYPE_IQ1_M:
|
|
8057
8485
|
case GGML_TYPE_IQ2_XXS:
|
|
8058
8486
|
case GGML_TYPE_IQ2_XS:
|
|
8059
8487
|
case GGML_TYPE_IQ2_S:
|
|
@@ -8128,6 +8556,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8128
8556
|
//case GGML_TYPE_Q4_K:
|
|
8129
8557
|
//case GGML_TYPE_Q5_K:
|
|
8130
8558
|
//case GGML_TYPE_Q6_K:
|
|
8559
|
+
//case GGML_TYPE_IQ1_S:
|
|
8560
|
+
//case GGML_TYPE_IQ1_M:
|
|
8131
8561
|
//case GGML_TYPE_IQ2_XXS:
|
|
8132
8562
|
//case GGML_TYPE_IQ2_XS:
|
|
8133
8563
|
//case GGML_TYPE_IQ2_S:
|
|
@@ -8151,6 +8581,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8151
8581
|
case GGML_TYPE_Q5_0:
|
|
8152
8582
|
case GGML_TYPE_Q5_1:
|
|
8153
8583
|
case GGML_TYPE_Q8_0:
|
|
8584
|
+
case GGML_TYPE_IQ1_S:
|
|
8585
|
+
case GGML_TYPE_IQ1_M:
|
|
8154
8586
|
case GGML_TYPE_IQ2_XXS:
|
|
8155
8587
|
case GGML_TYPE_IQ2_XS:
|
|
8156
8588
|
case GGML_TYPE_IQ2_S:
|
|
@@ -8206,17 +8638,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8206
8638
|
} break;
|
|
8207
8639
|
case GGML_OP_REPEAT:
|
|
8208
8640
|
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
|
|
8641
|
+
case GGML_OP_REPEAT_BACK:
|
|
8642
|
+
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
|
8209
8643
|
case GGML_OP_ROPE:
|
|
8210
|
-
|
|
8211
|
-
const int mode = ((const int32_t *) op->op_params)[2];
|
|
8212
|
-
if (mode & GGML_ROPE_TYPE_MROPE) {
|
|
8213
|
-
return false;
|
|
8214
|
-
}
|
|
8215
|
-
if (mode & GGML_ROPE_TYPE_VISION) {
|
|
8216
|
-
return false;
|
|
8217
|
-
}
|
|
8218
|
-
return ggml_is_contiguous(op->src[0]);
|
|
8219
|
-
}
|
|
8644
|
+
case GGML_OP_ROPE_BACK:
|
|
8220
8645
|
case GGML_OP_NONE:
|
|
8221
8646
|
case GGML_OP_RESHAPE:
|
|
8222
8647
|
case GGML_OP_VIEW:
|
|
@@ -8228,26 +8653,35 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_RMS_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
-        case GGML_OP_ACC:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         default:
             return false;
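The hunks above all extend `ggml_backend_vk_device_supports_op`, the per-operation whitelist the Vulkan backend exposes so the scheduler can fall back to the CPU for anything the device does not implement. As a rough illustration of that pattern only — hypothetical enums and names, not the ggml API — a capability gate of this shape looks like:

```cpp
// Minimal sketch (hypothetical types, not the ggml API): a backend "supports_op"
// gate that whitelists op/type combinations and reports false for everything else.
#include <cstdio>

enum class Op   { Add, Sub, MulMat, RopeBack, Unknown };
enum class Type { F32, F16, Q4_0 };

// Return true only for combinations the (hypothetical) backend implements.
static bool device_supports_op(Op op, Type src_type) {
    switch (op) {
        case Op::Add:
        case Op::Sub:
            // element-wise ops: only F32 inputs are handled
            return src_type == Type::F32;
        case Op::MulMat:
            // matmul: float and quantized inputs are both fine
            return src_type == Type::F32 || src_type == Type::F16 || src_type == Type::Q4_0;
        case Op::RopeBack:
            return src_type == Type::F32;
        default:
            return false; // unsupported ops get scheduled on the CPU instead
    }
}

int main() {
    std::printf("Sub on F16 supported: %d\n", device_supports_op(Op::Sub, Type::F16));
    std::printf("MulMat on Q4_0 supported: %d\n", device_supports_op(Op::MulMat, Type::Q4_0));
    return 0;
}
```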
@@ -8347,8 +8781,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
         /* .iface = */ ggml_backend_vk_reg_i,
         /* .context = */ nullptr,
     };
-
-    return &reg;
+    try {
+        ggml_vk_instance_init();
+        return &reg;
+    } catch (const vk::SystemError& e) {
+        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
+        return nullptr;
+    }
 }

 // Extension availability
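The hunk above wraps Vulkan instance initialization in a try/catch so that registration reports "no backend" (a null pointer) instead of letting an exception escape on machines without a usable driver. A minimal, self-contained sketch of that idea — hypothetical names, not the ggml or Vulkan API:

```cpp
// Minimal sketch: registration that catches an init failure and returns nullptr
// instead of propagating the exception to the caller.
#include <cstdio>
#include <stdexcept>

struct backend_reg { const char * name; };

// Stand-in for instance initialization that may throw when no device is usable.
static void fake_instance_init(bool fail) {
    if (fail) {
        throw std::runtime_error("no capable device found");
    }
}

static const backend_reg * backend_reg_get(bool fail) {
    static backend_reg reg = { "example-backend" };
    try {
        fake_instance_init(fail);
        return &reg;           // init succeeded: hand out the registration
    } catch (const std::exception & e) {
        std::fprintf(stderr, "backend init failed: %s\n", e.what());
        return nullptr;        // caller treats nullptr as "backend unavailable"
    }
}

int main() {
    std::printf("ok:   %p\n", (const void *) backend_reg_get(false));
    std::printf("fail: %p\n", (const void *) backend_reg_get(true));
    return 0;
}
```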
@@ -8515,8 +8954,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {

     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
-    ggml_tensor * src2 = tensor->src[2];
-    ggml_tensor * src3 = tensor->src[3];

     struct ggml_init_params iparams = {
         /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -8526,239 +8963,121 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {

     struct ggml_context * ggml_ctx = ggml_init(iparams);

-    struct ggml_tensor * src0_clone = nullptr;
-    struct ggml_tensor * src1_clone = nullptr;
-    struct ggml_tensor * src2_clone = nullptr;
-    struct ggml_tensor * src3_clone = nullptr;
-    struct ggml_tensor * tensor_clone = nullptr;
-
-    size_t src0_size;
-    size_t src1_size;
-    size_t src2_size;
-    size_t src3_size;
-
-    void * src0_buffer = nullptr;
-    void * src1_buffer = nullptr;
-    void * src2_buffer = nullptr;
-    void * src3_buffer = nullptr;
-
-    if (src0 != nullptr) {
-        src0_clone = ggml_dup_tensor(ggml_ctx, src0);
-
-        src0_size = ggml_nbytes(src0);
-
-        src0_buffer = malloc(src0_size);
-        src0_clone->data = src0_buffer;
-        if (ggml_backend_buffer_is_host(src0->buffer)) {
-            memcpy(src0_clone->data, src0->data, src0_size);
-            memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
-            if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
-                for (int i3 = 0; i3 < src0->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src0->ne[2]; i2++) {
-                        const int idx = i3*src0->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
-                    }
-                }
-
-                src0_clone->nb[0] = src0->nb[0];
-                src0_clone->nb[1] = src0->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src0_size >= buffer_gpu->size) {
-                    src0_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
-                memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
+    std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
+    std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
+    const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};

-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src0, "src0");
-        }
-    }
-    if (src1 != nullptr) {
-        src1_clone = ggml_dup_tensor(ggml_ctx, src1);
-
-        src1_size = ggml_nbytes(src1);
-
-        src1_buffer = malloc(src1_size);
-        src1_clone->data = src1_buffer;
-        if (ggml_backend_buffer_is_host(src1->buffer)) {
-            memcpy(src1_clone->data, src1->data, src1_size);
-            memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
-            if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
-                for (int i3 = 0; i3 < src1->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src1->ne[2]; i2++) {
-                        const int idx = i3*src1->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
-                    }
-                }
-
-                src1_clone->nb[0] = src1->nb[0];
-                src1_clone->nb[1] = src1->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src1_size >= buffer_gpu->size) {
-                    src1_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
-                memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src1, "src1");
-        }
-    }
-    if (src2 != nullptr) {
-        src2_clone = ggml_dup_tensor(ggml_ctx, src2);
-
-        src2_size = ggml_nbytes(src2);
-
-        src2_buffer = malloc(src2_size);
-        src2_clone->data = src2_buffer;
-        if (ggml_backend_buffer_is_host(src2->buffer)) {
-            memcpy(src2_clone->data, src2->data, src2_size);
-            memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
-            vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
-            if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
-                for (int i3 = 0; i3 < src2->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src2->ne[2]; i2++) {
-                        const int idx = i3*src2->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
-                    }
-                }
-
-                src2_clone->nb[0] = src2->nb[0];
-                src2_clone->nb[1] = src2->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src2_size >= buffer_gpu->size) {
-                    src2_size = buffer_gpu->size - offset;
-                }
-                ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
-                memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
+    struct ggml_tensor * tensor_clone = nullptr;

-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src2, "src2");
+    for (int i = 0; i < 6; i++) {
+        ggml_tensor * srci = tensor->src[i];
+        if (srci == nullptr) {
+            continue;
         }
-    }
-    if (src3 != nullptr) {
-        src3_clone = ggml_dup_tensor(ggml_ctx, src3);
+        ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
+        size_t srci_size = ggml_nbytes(srci);

-        src3_size = ggml_nbytes(src3);
+        src_clone[i] = srci_clone;
+        src_size[i] = ggml_nbytes(srci);
+        src_buffer[i] = malloc(srci_size);

-        src3_buffer = malloc(src3_size);
-        src3_clone->data = src3_buffer;
-        if (ggml_backend_buffer_is_host(src3->buffer)) {
-            memcpy(src3_clone->data, src3->data, src3_size);
-            memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
-            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
+        srci_clone->data = src_buffer[i];
+        if (ggml_backend_buffer_is_host(srci->buffer)) {
+            memcpy(srci_clone->data, srci->data, srci_size);
+            memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
+        } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
+            ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
             vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
-            uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
-            if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
-                for (int i3 = 0; i3 < src3->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src3->ne[2]; i2++) {
-                        const int idx = i3*src3->ne[2] + i2;
-                        ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
+            uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
+            if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
+                for (int i3 = 0; i3 < srci->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < srci->ne[2]; i2++) {
+                        const int idx = i3*srci->ne[2] + i2;
+                        ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
                     }
                 }

-                src3_clone->nb[0] = src3->nb[0];
-                src3_clone->nb[1] = src3->nb[1];
+                srci_clone->nb[0] = srci->nb[0];
+                srci_clone->nb[1] = srci->nb[1];
                 for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
+                    srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
                 }
             } else {
-                if (offset + src3_size >= buffer_gpu->size) {
-                    src3_size = buffer_gpu->size - offset;
+                if (offset + srci_size >= buffer_gpu->size) {
+                    srci_size = buffer_gpu->size - offset;
                 }
-                ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
-                memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
+                ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
+                memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
             }
         } else {
             GGML_ABORT("fatal error");
         }

         if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(src3, "src3");
+            ggml_vk_print_tensor(srci, srci_name[i]);
         }
     }

     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float *params = (const float *)tensor->op_params;
-        tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
+        tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
     } else if (tensor->op == GGML_OP_MUL_MAT) {
-        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
-        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+        tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_DIV) {
-        tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
-        tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
+        tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_SCALE) {
-        tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
+        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_SQR) {
-        tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
-        tensor_clone = ggml_sin(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {
-        tensor_clone = ggml_cos(ggml_ctx, src0_clone);
+        tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
-        tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_PAD) {
-        tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
+        tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
     } else if (tensor->op == GGML_OP_REPEAT) {
-        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
+        tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
+    } else if (tensor->op == GGML_OP_REPEAT_BACK) {
+        tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
     } else if (tensor->op == GGML_OP_ADD) {
-        tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ACC) {
-        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+        tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
     } else if (tensor->op == GGML_OP_NORM) {
-        tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+        tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
+        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
-        tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+        tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+    } else if (tensor->op == GGML_OP_SILU_BACK) {
+        tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
-            tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
+            tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
+    } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+        tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
-    } else if (tensor->op == GGML_OP_ROPE) {
+        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
         //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
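The hunk above replaces four copy-pasted per-source clone blocks (src0 through src3) with one loop over a six-slot array, so every source slot of the tensor goes through a single copy path. A minimal sketch of the same refactoring pattern — toy types only, not ggml:

```cpp
// Minimal sketch (hypothetical tensor struct, not ggml): one loop over a fixed-size
// array of optional sources instead of a hand-written block per source.
#include <array>
#include <cstdio>
#include <vector>

struct toy_tensor {
    std::vector<float> data;
};

int main() {
    toy_tensor a{{1, 2, 3}}, b{{4, 5}};
    // Up to 6 optional sources, like a fixed-width src[] slot array.
    std::array<const toy_tensor *, 6> src       = { &a, &b, nullptr, nullptr, nullptr, nullptr };
    std::array<std::vector<float>, 6> src_clone = {};

    for (size_t i = 0; i < src.size(); i++) {
        if (src[i] == nullptr) {
            continue;                       // unused slot, nothing to clone
        }
        src_clone[i] = src[i]->data;        // one copy path shared by every slot
        std::printf("cloned src%zu: %zu elements\n", i, src_clone[i].size());
    }
    return 0;
}
```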
@@ -8769,23 +9088,39 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const float attn_factor = ((float *) tensor->op_params)[8];
         const float beta_fast = ((float *) tensor->op_params)[9];
         const float beta_slow = ((float *) tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        if (mode & GGML_ROPE_TYPE_MROPE) {
+            int32_t *sections = ((int32_t *) tensor->op_params) + 11;
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
+        } else {
+            if (tensor->op == GGML_OP_ROPE) {
+                tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            } else {
+                tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
+        }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
-            tensor_clone = ggml_silu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_GELU:
-            tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_GELU_QUICK:
-            tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
+            tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_RELU:
-            tensor_clone = ggml_relu(ggml_ctx, src0_clone);
+            tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
             break;
         case GGML_UNARY_OP_TANH:
-            tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
+            tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
+            break;
+        case GGML_UNARY_OP_SIGMOID:
+            tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
             break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8793,28 +9128,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         }
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
-            tensor_clone = ggml_dup(ggml_ctx, src0_clone);
+            tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
             tensor_clone->type = tensor->type;
         } else {
-            tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
+            tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
     } else if (tensor->op == GGML_OP_CONT) {
-        tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {
-        tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_VIEW) {
-        tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
+        tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_PERMUTE) {
         int32_t * params = (int32_t *)tensor->op_params;
-        tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
+        tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
     } else if (tensor->op == GGML_OP_TRANSPOSE) {
-        tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
+        tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_GET_ROWS) {
-        tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
+        tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_ARGSORT) {
-        tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
+        tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_SUM) {
+        tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
-        tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
+        tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_ARGMAX) {
+        tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
+        tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_IM2COL) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
@@ -8824,11 +9165,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t d1 = tensor->op_params[5];

         const bool is_2D = tensor->op_params[6] == 1;
-        tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
+        tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
     } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
-        tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+        tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
     } else if (tensor->op == GGML_OP_POOL_2D) {
         enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
         const int32_t k0 = tensor->op_params[1];
@@ -8838,13 +9179,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t p0 = tensor->op_params[5];
         const int32_t p1 = tensor->op_params[6];

-        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
+        tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
-        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
+        tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
-        tensor_clone = ggml_rwkv_wkv6(ggml_ctx,
-
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8866,11 +9211,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     memcpy(comp_result, tensor_clone->data, comp_size);
     memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);

-    if (src0 != nullptr) {
-        free(src0_buffer);
-    }
-    if (src1 != nullptr) {
-        free(src1_buffer);
+    for (int i = 0; i < 6; i++) {
+        if (src_buffer[i] != nullptr) {
+            free(src_buffer[i]);
+        }
     }

     ggml_free(ggml_ctx);
@@ -8934,6 +9278,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
             } else if (tensor->type == GGML_TYPE_I32) {
                 correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
                 result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+            } else if (tensor->type == GGML_TYPE_I64) {
+                correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+                result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
             } else {
                 std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
             }
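The added branch compares I64 results the same way the existing I32 branch does: compute a byte offset from the per-dimension strides and reinterpret the element at that address. A small self-contained sketch of that strided, typed read — hypothetical layout and names, not ggml:

```cpp
// Minimal sketch: read one element of a strided 4-D buffer using per-dimension
// byte strides nb[], switching on the element type.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

enum class elem_type { I32, I64 };

static int64_t read_element(const void * base, const size_t nb[4],
                            int i0, int i1, int i2, int i3, elem_type t) {
    const char * p = (const char *) base + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0];
    if (t == elem_type::I64) {
        int64_t v; std::memcpy(&v, p, sizeof(v)); return v;
    }
    int32_t v; std::memcpy(&v, p, sizeof(v)); return v;
}

int main() {
    std::vector<int64_t> buf = {10, 20, 30, 40};
    // contiguous 4x1x1x1 layout of int64 elements
    const size_t nb[4] = { sizeof(int64_t), 4*sizeof(int64_t), 4*sizeof(int64_t), 4*sizeof(int64_t) };
    std::printf("%lld\n", (long long) read_element(buf.data(), nb, 2, 0, 0, 0, elem_type::I64));
    return 0;
}
```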