@fugood/llama.node 0.3.12 → 0.3.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -0
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +13 -4
  21. package/src/llama.cpp/.github/workflows/build.yml +35 -3
  22. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  23. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  24. package/src/llama.cpp/common/CMakeLists.txt +20 -3
  25. package/src/llama.cpp/common/arg.cpp +180 -3
  26. package/src/llama.cpp/common/chat-template.hpp +21 -7
  27. package/src/llama.cpp/common/chat.cpp +220 -101
  28. package/src/llama.cpp/common/chat.hpp +3 -0
  29. package/src/llama.cpp/common/common.h +15 -7
  30. package/src/llama.cpp/common/llguidance.cpp +3 -3
  31. package/src/llama.cpp/common/log.cpp +1 -0
  32. package/src/llama.cpp/common/log.h +2 -1
  33. package/src/llama.cpp/common/minja.hpp +24 -9
  34. package/src/llama.cpp/common/sampling.cpp +52 -46
  35. package/src/llama.cpp/common/speculative.h +1 -1
  36. package/src/llama.cpp/docs/build.md +2 -2
  37. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -1
  38. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  39. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  40. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  41. package/src/llama.cpp/examples/run/run.cpp +5 -12
  42. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  43. package/src/llama.cpp/examples/server/httplib.h +381 -292
  44. package/src/llama.cpp/examples/server/server.cpp +58 -47
  45. package/src/llama.cpp/examples/server/utils.hpp +7 -5
  46. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  47. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  48. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  49. package/src/llama.cpp/ggml/include/ggml.h +1 -1
  50. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  51. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +6 -12
  52. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +852 -268
  53. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +200 -107
  54. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  56. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +2 -2
  57. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +26 -4
  58. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +6 -7
  59. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +812 -569
  60. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +25 -1
  61. package/src/llama.cpp/ggml/src/ggml.c +1 -1
  62. package/src/llama.cpp/include/llama.h +14 -10
  63. package/src/llama.cpp/src/llama-grammar.cpp +1 -1
  64. package/src/llama.cpp/src/llama-grammar.h +1 -1
  65. package/src/llama.cpp/src/llama-impl.h +6 -6
  66. package/src/llama.cpp/src/llama-kv-cache.h +1 -1
  67. package/src/llama.cpp/src/llama-mmap.h +1 -0
  68. package/src/llama.cpp/src/llama-model.cpp +1 -1
  69. package/src/llama.cpp/src/llama-sampling.cpp +131 -57
  70. package/src/llama.cpp/src/llama.cpp +7 -5
  71. package/src/llama.cpp/src/unicode.cpp +9 -2
  72. package/src/llama.cpp/tests/test-backend-ops.cpp +5 -5
  73. package/src/llama.cpp/tests/test-chat.cpp +237 -69
  74. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  75. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
@@ -167,6 +167,7 @@ struct vk_device_struct {
167
167
  uint32_t subgroup_size;
168
168
  uint32_t shader_core_count;
169
169
  bool uma;
170
+ bool prefer_host_memory;
170
171
  bool float_controls_rte_fp16;
171
172
 
172
173
  bool subgroup_size_control;
@@ -184,12 +185,12 @@ struct vk_device_struct {
184
185
 
185
186
  size_t idx;
186
187
 
187
- bool mul_mat_l;
188
- bool mul_mat_m;
189
- bool mul_mat_s;
190
- bool mul_mat_id_l;
191
- bool mul_mat_id_m;
192
- bool mul_mat_id_s;
188
+ bool mul_mat_l[GGML_TYPE_COUNT];
189
+ bool mul_mat_m[GGML_TYPE_COUNT];
190
+ bool mul_mat_s[GGML_TYPE_COUNT];
191
+ bool mul_mat_id_l[GGML_TYPE_COUNT];
192
+ bool mul_mat_id_m[GGML_TYPE_COUNT];
193
+ bool mul_mat_id_s[GGML_TYPE_COUNT];
193
194
 
194
195
  // set to true to indicate that some shaders need to be compiled after the dryrun
195
196
  bool need_compiles {};
@@ -221,6 +222,7 @@ struct vk_device_struct {
221
222
  vk_pipeline pipeline_acc_f32;
222
223
  vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
223
224
  vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
225
+ vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
224
226
  vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
225
227
  vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
226
228
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -231,7 +233,7 @@ struct vk_device_struct {
231
233
  vk_pipeline pipeline_cos_f32;
232
234
  vk_pipeline pipeline_clamp_f32;
233
235
  vk_pipeline pipeline_pad_f32;
234
- vk_pipeline pipeline_repeat_f32;
236
+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
235
237
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
236
238
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
237
239
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -250,12 +252,17 @@ struct vk_device_struct {
250
252
  vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
251
253
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
252
254
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
255
+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
256
+ vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
253
257
  vk_pipeline pipeline_argsort_f32;
254
258
  vk_pipeline pipeline_sum_rows_f32;
259
+ vk_pipeline pipeline_argmax_f32;
260
+ vk_pipeline pipeline_count_equal_i32;
255
261
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
256
262
  vk_pipeline pipeline_timestep_embedding_f32;
257
263
  vk_pipeline pipeline_pool2d_f32;
258
264
  vk_pipeline pipeline_rwkv_wkv6_f32;
265
+ vk_pipeline pipeline_opt_step_adamw_f32;
259
266
 
260
267
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
261
268
  vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -493,6 +500,10 @@ struct vk_op_rope_push_constants {
493
500
  float corr_dims[2];
494
501
  float theta_scale;
495
502
  uint32_t has_ff;
503
+ uint32_t ne02;
504
+ uint32_t s1;
505
+ uint32_t s2;
506
+ int32_t sections[4];
496
507
  };
497
508
 
498
509
  struct vk_op_soft_max_push_constants {
@@ -1294,7 +1305,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
1294
1305
  static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
1295
1306
  vk_buffer buf;
1296
1307
  try {
1297
- if (device->uma) {
1308
+ if (device->prefer_host_memory) {
1309
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
1310
+ } else if (device->uma) {
1298
1311
  // Fall back to host memory type
1299
1312
  buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
1300
1313
  } else {
@@ -1378,7 +1391,37 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
1378
1391
  return {64, 64};
1379
1392
  };
1380
1393
 
1381
- static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
1394
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
1395
+
1396
+ uint32_t lut_size = 0;
1397
+ switch (src0_type) {
1398
+ case GGML_TYPE_IQ1_S:
1399
+ case GGML_TYPE_IQ1_M:
1400
+ lut_size = 2*2048;
1401
+ break;
1402
+ case GGML_TYPE_IQ2_XXS:
1403
+ lut_size = 8*256;
1404
+ break;
1405
+ case GGML_TYPE_IQ2_XS:
1406
+ lut_size = 8*512;
1407
+ break;
1408
+ case GGML_TYPE_IQ2_S:
1409
+ lut_size = 8*1024;
1410
+ break;
1411
+ case GGML_TYPE_IQ3_XXS:
1412
+ lut_size = 4*256;
1413
+ break;
1414
+ case GGML_TYPE_IQ3_S:
1415
+ lut_size = 4*512;
1416
+ break;
1417
+ case GGML_TYPE_IQ4_NL:
1418
+ case GGML_TYPE_IQ4_XS:
1419
+ lut_size = 4*16;
1420
+ break;
1421
+ default:
1422
+ break;
1423
+ }
1424
+
1382
1425
  // Needs to be kept up to date on shader changes
1383
1426
  const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
1384
1427
  const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
@@ -1388,13 +1431,20 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1388
1431
  const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
1389
1432
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1390
1433
 
1391
- return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
1434
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
1435
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
1436
+
1437
+ VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
1438
+ "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
1439
+
1440
+ return supported;
1392
1441
  }
1393
1442
 
1394
1443
  static void ggml_vk_load_shaders(vk_device& device) {
1395
1444
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
1396
1445
 
1397
1446
  // some shaders have a minimum subgroup size
1447
+ const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
1398
1448
  const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
1399
1449
  const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
1400
1450
 
@@ -1457,13 +1507,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
1457
1507
  const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
1458
1508
  const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
1459
1509
 
1460
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1461
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1462
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1510
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
1511
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
1512
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
1463
1513
 
1464
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1465
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1466
- s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1514
+ l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
1515
+ m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
1516
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
1467
1517
 
1468
1518
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1469
1519
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
@@ -1472,62 +1522,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
1472
1522
  m_align = 64;
1473
1523
  s_align = 32;
1474
1524
 
1475
- // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1476
- // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1477
- // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1478
- // But the numbers happen to work out for 32KB shared memory size that when using the medium
1479
- // size there's enough room for everything, and we assert for this.
1480
- uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1481
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1482
- l_warptile = m_warptile;
1483
- l_wg_denoms = m_wg_denoms;
1484
- shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1485
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1486
- }
1487
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1488
- // assert mul_mat_mat_id shaders will fit.
1489
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1490
- }
1491
-
1492
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1493
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1494
- if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
1495
- l_warptile_mmq = m_warptile_mmq;
1496
- l_mmq_wg_denoms = m_mmq_wg_denoms;
1497
- } else {
1498
- l_warptile_mmq = s_warptile_mmq;
1499
- l_mmq_wg_denoms = s_mmq_wg_denoms;
1525
+ for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
1526
+ ggml_type t = (ggml_type)i;
1527
+ // Disable medium and large matrix multiplication if not enough shared memory is available
1528
+ // Check mmq warptiles as the largest configuration
1529
+ // Throw an error if not enough for any matrix multiplication is available
1530
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
1531
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1532
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1533
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
1534
+ device->mul_mat_m[i] = false;
1535
+ device->mul_mat_l[i] = false;
1536
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
1537
+ device->mul_mat_l[i] = false;
1538
+ }
1539
+
1540
+ // Disable mul_mat_id if not enough shared memory is available
1541
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
1542
+ device->mul_mat_id_s[i] = false;
1543
+ device->mul_mat_id_m[i] = false;
1544
+ device->mul_mat_id_l[i] = false;
1545
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
1546
+ device->mul_mat_id_m[i] = false;
1547
+ device->mul_mat_id_l[i] = false;
1548
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
1549
+ device->mul_mat_id_l[i] = false;
1500
1550
  }
1501
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1502
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1503
- }
1504
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1505
- // assert mul_mat_mat_id shaders will fit.
1506
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1507
- }
1508
- // Disable medium and large matrix multiplication if not enough shared memory is available
1509
- // Check mmq warptiles as the largest configuration
1510
- // Throw an error if not enough for any matrix multiplication is available
1511
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
1512
- std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1513
- throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1514
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
1515
- device->mul_mat_m = false;
1516
- device->mul_mat_l = false;
1517
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
1518
- device->mul_mat_l = false;
1519
- }
1520
-
1521
- // Disable mul_mat_id if not enough shared memory is available
1522
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
1523
- device->mul_mat_id_s = false;
1524
- device->mul_mat_id_m = false;
1525
- device->mul_mat_id_l = false;
1526
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
1527
- device->mul_mat_id_m = false;
1528
- device->mul_mat_id_l = false;
1529
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
1530
- device->mul_mat_id_l = false;
1531
1551
  }
1532
1552
  }
1533
1553
 
@@ -1617,6 +1637,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1617
1637
  //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1618
1638
  //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1619
1639
  //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1640
+ //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
1641
+ //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
1620
1642
  //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
1621
1643
  //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
1622
1644
  //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
@@ -1651,6 +1673,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1651
1673
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1652
1674
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1653
1675
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1676
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1677
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1654
1678
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1655
1679
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1656
1680
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1670,6 +1694,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1670
1694
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1671
1695
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1672
1696
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1697
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1698
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1673
1699
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1674
1700
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1675
1701
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1684,119 +1710,124 @@ static void ggml_vk_load_shaders(vk_device& device) {
1684
1710
  #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1685
1711
  if (device->coopmat_support) {
1686
1712
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1687
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1688
- if (device->mul_mat ## ID ## _l) \
1713
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1714
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1689
1715
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
1690
- if (device->mul_mat ## ID ## _m) \
1716
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1691
1717
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
1692
- if (device->mul_mat ## ID ## _s) \
1718
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1693
1719
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
1694
- if (device->mul_mat ## ID ## _l) \
1720
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1695
1721
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
1696
- if (device->mul_mat ## ID ## _m) \
1722
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1697
1723
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
1698
- if (device->mul_mat ## ID ## _s) \
1724
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1699
1725
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
1700
1726
 
1701
1727
  // Create 2 variants, {f16,f32} accumulator
1702
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1728
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1703
1729
  if (device->coopmat_acc_f16_support) { \
1704
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1730
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1705
1731
  } \
1706
1732
  if (device->coopmat_acc_f32_support) { \
1707
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1733
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1708
1734
  } \
1709
1735
 
1710
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1711
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1712
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1713
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1736
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1737
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1738
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1739
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1714
1740
 
1715
1741
  if (device->coopmat_acc_f16_support) {
1716
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1717
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1718
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1719
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1720
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1721
-
1722
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1723
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1724
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1725
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1726
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1727
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1728
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1729
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1730
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1731
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1732
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1733
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1742
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1743
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1744
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1745
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1746
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1747
+
1748
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1749
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1750
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1751
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1752
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1753
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1754
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1755
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1756
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1757
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1758
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1759
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1760
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1761
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1734
1762
  } else {
1735
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1736
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1737
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1738
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1739
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1740
-
1741
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1742
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1743
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1744
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1745
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1746
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1747
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1748
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1749
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1750
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1751
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1752
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1753
- }
1754
-
1755
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1756
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1757
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1758
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1759
- CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1760
-
1761
- if (device->coopmat_acc_f16_support) {
1762
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1763
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1764
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1765
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1766
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1767
-
1768
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1769
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1770
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1771
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1772
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1773
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1774
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1775
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1776
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1777
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1778
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1779
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1780
- } else {
1781
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1782
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1783
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1784
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1785
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1786
-
1787
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1788
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1789
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1790
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1791
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1792
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1793
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1794
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1795
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1796
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1798
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1799
- }
1763
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1764
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1765
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1766
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1767
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1768
+
1769
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1770
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1771
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1772
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1773
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1774
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1775
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1776
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1777
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1778
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1779
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1780
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1781
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1782
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1783
+ }
1784
+
1785
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1786
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1787
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1788
+
1789
+ if (device->coopmat_acc_f16_support) {
1790
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1791
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1792
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1793
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1794
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1795
+
1796
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1798
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1799
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1800
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1801
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1802
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1803
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1804
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1805
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1806
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1807
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1808
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1809
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1810
+ } else {
1811
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1812
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1813
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1814
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1815
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1816
+
1817
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1818
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1819
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1820
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1821
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1822
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1823
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1824
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1825
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1826
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1827
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1828
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1829
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1830
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1800
1831
  }
1801
1832
  #undef CREATE_MM2
1802
1833
  #undef CREATE_MM
@@ -1804,141 +1835,143 @@ static void ggml_vk_load_shaders(vk_device& device) {
1804
1835
  #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1805
1836
  if (device->fp16) {
1806
1837
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1807
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1808
- if (device->mul_mat ## ID ## _l) \
1838
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1839
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1809
1840
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1810
- if (device->mul_mat ## ID ## _m) \
1841
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1811
1842
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1812
- if (device->mul_mat ## ID ## _s) \
1843
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1813
1844
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1814
- if (device->mul_mat ## ID ## _l) \
1845
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1815
1846
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1816
- if (device->mul_mat ## ID ## _m) \
1847
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1817
1848
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1818
- if (device->mul_mat ## ID ## _s) \
1849
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1819
1850
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1820
1851
 
1821
1852
  // Create 2 variants, {f16,f32} accumulator
1822
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1823
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1824
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1825
-
1826
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1827
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1828
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1829
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1830
-
1831
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1832
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1833
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1834
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1835
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1836
-
1837
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1838
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1839
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1840
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1841
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1842
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1843
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1844
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1845
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1846
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1847
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1848
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1849
-
1850
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1851
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1852
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1853
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1854
- CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1855
-
1856
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1857
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1858
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1859
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1860
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1861
-
1862
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1863
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1864
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1865
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1866
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1867
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1868
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1869
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1870
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1871
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1872
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1873
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1874
- }
1853
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1854
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1855
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1856
+
1857
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1858
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1859
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1860
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1861
+
1862
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1863
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1864
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1865
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1866
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1867
+
1868
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1869
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1870
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1871
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1872
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1873
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1874
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1875
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1876
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1877
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1878
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1879
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1880
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1881
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1882
+
1883
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1884
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1885
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1886
+
1887
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1888
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1889
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1890
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1891
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1892
+
1893
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1894
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1895
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1896
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1897
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1898
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1899
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1900
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1901
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1902
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1903
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1904
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1905
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1906
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1875
1907
  #undef CREATE_MM2
1876
1908
  #undef CREATE_MM
1877
1909
  } else {
1878
1910
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1879
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1880
- if (device->mul_mat ## ID ## _l) \
1911
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1912
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1881
1913
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1882
- if (device->mul_mat ## ID ## _m) \
1914
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1883
1915
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1884
- if (device->mul_mat ## ID ## _s) \
1916
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1885
1917
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1886
- if (device->mul_mat ## ID ## _l) \
1918
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1887
1919
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1888
- if (device->mul_mat ## ID ## _m) \
1920
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1889
1921
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1890
- if (device->mul_mat ## ID ## _s) \
1922
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1891
1923
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1892
1924
 
1893
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1894
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1895
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1896
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1897
-
1898
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1899
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1900
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1901
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1902
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1903
-
1904
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1905
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1906
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1907
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1908
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1909
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1910
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1911
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1912
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1913
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1914
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1915
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1916
-
1917
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1918
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1919
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1920
- CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1921
- CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1922
-
1923
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1924
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1925
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1926
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1927
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1928
-
1929
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1930
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1931
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1932
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1933
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1934
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1935
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1936
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1937
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1938
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1939
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1940
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1941
- }
1925
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1926
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1927
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1928
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1929
+
1930
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1931
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1932
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1933
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1934
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1935
+
1936
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1937
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1938
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1939
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1940
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1941
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1942
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1943
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1944
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1945
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1946
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1947
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1948
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1949
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1950
+
1951
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1952
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1953
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1954
+
1955
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1956
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1957
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1958
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1959
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1960
+
1961
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1962
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1963
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1964
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1965
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1966
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1967
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1968
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1969
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1970
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1971
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1972
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1973
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1974
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1942
1975
  #undef CREATE_MM
1943
1976
  }
1944
1977
 
@@ -1968,6 +2001,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1968
2001
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1969
2002
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1970
2003
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
2004
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
2005
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1971
2006
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1972
2007
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1973
2008
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
@@ -1988,6 +2023,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1988
2023
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1989
2024
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1990
2025
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
2026
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
2027
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1991
2028
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1992
2029
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1993
2030
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
@@ -2009,6 +2046,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2009
2046
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2010
2047
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2011
2048
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2049
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2050
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2012
2051
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2013
2052
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2014
2053
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
@@ -2029,6 +2068,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2029
2068
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2030
2069
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
2031
2070
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
2071
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2072
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2032
2073
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2033
2074
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2034
2075
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
@@ -2045,6 +2086,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2045
2086
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2046
2087
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2047
2088
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2089
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2090
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2048
2091
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2049
2092
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2050
2093
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2060,6 +2103,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2060
2103
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2061
2104
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2062
2105
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2106
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2107
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2063
2108
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2064
2109
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2065
2110
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2106,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2106
2151
 
2107
2152
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2108
2153
 
2154
+ ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2155
+ ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2109
2156
  ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2110
2157
  ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2111
2158
  ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -2128,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2128
2175
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2129
2176
 
2130
2177
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2178
+ ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2131
2179
 
2132
2180
  ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2133
2181
  ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -2145,19 +2193,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
2145
2193
 
2146
2194
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2147
2195
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2196
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2197
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2148
2198
 
2149
2199
  if (device->float_controls_rte_fp16) {
2150
2200
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2151
2201
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2202
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2203
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2152
2204
  } else {
2153
2205
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2154
2206
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2207
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2208
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2155
2209
  }
2156
2210
 
2157
2211
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
2158
2212
 
2213
+ ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2214
+
2159
2215
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2160
2216
 
2217
+ ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2218
+
2161
2219
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2162
2220
  if (device->float_controls_rte_fp16) {
2163
2221
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -2171,6 +2229,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2171
2229
 
2172
2230
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2173
2231
 
2232
+ ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2233
+
2174
2234
  for (auto &c : compiles) {
2175
2235
  c.wait();
2176
2236
  }
@@ -2206,6 +2266,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2206
2266
  device->physical_device = physical_devices[dev_num];
2207
2267
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
2208
2268
 
2269
+ const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
2270
+ device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
2271
+
2209
2272
  bool fp16_storage = false;
2210
2273
  bool fp16_compute = false;
2211
2274
  bool maintenance4_support = false;
@@ -2623,34 +2686,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
2623
2686
 
2624
2687
  // Shaders
2625
2688
  // Disable matmul tile sizes early if performance low or not supported
2626
- switch (device->vendor_id) {
2689
+ for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
2690
+ switch (device->vendor_id) {
2627
2691
  #ifndef GGML_VULKAN_RUN_TESTS
2628
- case VK_VENDOR_ID_AMD:
2629
- case VK_VENDOR_ID_INTEL:
2630
- device->mul_mat_l = false;
2631
- device->mul_mat_m = true;
2632
- device->mul_mat_s = true;
2633
- device->mul_mat_id_l = false;
2634
- device->mul_mat_id_m = true;
2635
- device->mul_mat_id_s = true;
2636
- break;
2637
- case VK_VENDOR_ID_APPLE:
2638
- device->mul_mat_l = false;
2639
- device->mul_mat_m = true;
2640
- device->mul_mat_s = false;
2641
- device->mul_mat_id_l = false;
2642
- device->mul_mat_id_m = true;
2643
- device->mul_mat_id_s = false;
2644
- break;
2692
+ case VK_VENDOR_ID_AMD:
2693
+ case VK_VENDOR_ID_INTEL:
2694
+ device->mul_mat_l[i] = false;
2695
+ device->mul_mat_m[i] = true;
2696
+ device->mul_mat_s[i] = true;
2697
+ device->mul_mat_id_l[i] = false;
2698
+ device->mul_mat_id_m[i] = true;
2699
+ device->mul_mat_id_s[i] = true;
2700
+ break;
2701
+ case VK_VENDOR_ID_APPLE:
2702
+ device->mul_mat_l[i] = false;
2703
+ device->mul_mat_m[i] = true;
2704
+ device->mul_mat_s[i] = false;
2705
+ device->mul_mat_id_l[i] = false;
2706
+ device->mul_mat_id_m[i] = true;
2707
+ device->mul_mat_id_s[i] = false;
2708
+ break;
2645
2709
  #endif
2646
- default:
2647
- device->mul_mat_l = true;
2648
- device->mul_mat_m = true;
2649
- device->mul_mat_s = true;
2650
- device->mul_mat_id_l = true;
2651
- device->mul_mat_id_m = true;
2652
- device->mul_mat_id_s = true;
2653
- break;
2710
+ default:
2711
+ device->mul_mat_l[i] = true;
2712
+ device->mul_mat_m[i] = true;
2713
+ device->mul_mat_s[i] = true;
2714
+ device->mul_mat_id_l[i] = true;
2715
+ device->mul_mat_id_m[i] = true;
2716
+ device->mul_mat_id_s[i] = true;
2717
+ break;
2718
+ }
2654
2719
  }
2655
2720
 
2656
2721
  ggml_vk_load_shaders(device);
@@ -2780,8 +2845,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2780
2845
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2781
2846
 
2782
2847
  std::string device_name = props2.properties.deviceName.data();
2783
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
2784
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
2848
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
2849
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
2850
+ props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
2785
2851
 
2786
2852
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
2787
2853
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2791,14 +2857,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2791
2857
  static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
2792
2858
  static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
2793
2859
 
2794
- void ggml_vk_instance_init() {
2860
+ static void ggml_vk_instance_init() {
2795
2861
  if (vk_instance_initialized) {
2796
2862
  return;
2797
2863
  }
2798
2864
  VK_LOG_DEBUG("ggml_vk_instance_init()");
2799
2865
 
2800
- vk_instance_initialized = true;
2801
-
2802
2866
  uint32_t api_version = vk::enumerateInstanceVersion();
2803
2867
 
2804
2868
  if (api_version < VK_API_VERSION_1_2) {
@@ -2849,6 +2913,7 @@ void ggml_vk_instance_init() {
2849
2913
  GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
2850
2914
  }
2851
2915
  vk_instance.instance = vk::createInstance(instance_create_info);
2916
+ vk_instance_initialized = true;
2852
2917
 
2853
2918
  size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
2854
2919
 
@@ -2873,7 +2938,7 @@ void ggml_vk_instance_init() {
2873
2938
  // Make sure at least one device exists
2874
2939
  if (devices.empty()) {
2875
2940
  std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
2876
- GGML_ABORT("fatal error");
2941
+ return;
2877
2942
  }
2878
2943
 
2879
2944
  // Default to using all dedicated GPUs
@@ -3007,6 +3072,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
3007
3072
  case GGML_TYPE_Q4_K:
3008
3073
  case GGML_TYPE_Q5_K:
3009
3074
  case GGML_TYPE_Q6_K:
3075
+ case GGML_TYPE_IQ1_S:
3076
+ case GGML_TYPE_IQ1_M:
3010
3077
  case GGML_TYPE_IQ2_XXS:
3011
3078
  case GGML_TYPE_IQ2_XS:
3012
3079
  case GGML_TYPE_IQ2_S:
@@ -3061,6 +3128,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
3061
3128
  case GGML_TYPE_Q4_K:
3062
3129
  case GGML_TYPE_Q5_K:
3063
3130
  case GGML_TYPE_Q6_K:
3131
+ case GGML_TYPE_IQ1_S:
3132
+ case GGML_TYPE_IQ1_M:
3064
3133
  case GGML_TYPE_IQ2_XXS:
3065
3134
  case GGML_TYPE_IQ2_XS:
3066
3135
  case GGML_TYPE_IQ2_S:
@@ -3098,6 +3167,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
3098
3167
  case GGML_TYPE_Q4_K:
3099
3168
  case GGML_TYPE_Q5_K:
3100
3169
  case GGML_TYPE_Q6_K:
3170
+ case GGML_TYPE_IQ1_S:
3171
+ case GGML_TYPE_IQ1_M:
3101
3172
  case GGML_TYPE_IQ2_XXS:
3102
3173
  case GGML_TYPE_IQ2_XS:
3103
3174
  case GGML_TYPE_IQ2_S:
@@ -3147,6 +3218,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
3147
3218
  case GGML_TYPE_Q4_K:
3148
3219
  case GGML_TYPE_Q5_K:
3149
3220
  case GGML_TYPE_Q6_K:
3221
+ case GGML_TYPE_IQ1_S:
3222
+ case GGML_TYPE_IQ1_M:
3150
3223
  case GGML_TYPE_IQ2_XXS:
3151
3224
  case GGML_TYPE_IQ2_XS:
3152
3225
  case GGML_TYPE_IQ2_S:
@@ -3179,6 +3252,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
3179
3252
  case GGML_TYPE_Q4_K:
3180
3253
  case GGML_TYPE_Q5_K:
3181
3254
  case GGML_TYPE_Q6_K:
3255
+ case GGML_TYPE_IQ1_S:
3256
+ case GGML_TYPE_IQ1_M:
3182
3257
  case GGML_TYPE_IQ2_XXS:
3183
3258
  case GGML_TYPE_IQ2_XS:
3184
3259
  case GGML_TYPE_IQ2_S:
@@ -3721,6 +3796,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
3721
3796
  }
3722
3797
  }
3723
3798
 
3799
+ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3800
+ VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
3801
+
3802
+ ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
3803
+ }
3804
+
3724
3805
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3725
3806
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
3726
3807
 
@@ -3755,31 +3836,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
3755
3836
  return split_k;
3756
3837
  }
3757
3838
 
3758
- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3759
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3839
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3840
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3760
3841
 
3761
3842
  if (ctx->device->coopmat2) {
3762
- if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3843
+ if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
3763
3844
  return aligned ? mmp->a_l : mmp->l;
3764
3845
  }
3765
- if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3846
+ if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
3766
3847
  return aligned ? mmp->a_m : mmp->m;
3767
3848
  }
3768
3849
  return aligned ? mmp->a_s : mmp->s;
3769
3850
  }
3770
3851
 
3771
- if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
3852
+ if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
3772
3853
  return aligned ? mmp->a_s : mmp->s;
3773
3854
  }
3774
- if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
3855
+ if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
3775
3856
  return aligned ? mmp->a_m : mmp->m;
3776
3857
  }
3777
3858
  return aligned ? mmp->a_l : mmp->l;
3778
3859
  }
3779
3860
 
3780
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3781
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3782
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
3861
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
3862
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
3863
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
3783
3864
  }
3784
3865
 
3785
3866
  static void ggml_vk_matmul(
@@ -3806,31 +3887,31 @@ static void ggml_vk_matmul(
3806
3887
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
3807
3888
  }
3808
3889
 
3809
- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3810
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3890
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3891
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3811
3892
 
3812
3893
  if (ctx->device->coopmat2) {
3813
- if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
3894
+ if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
3814
3895
  return aligned ? mmp->a_l : mmp->l;
3815
3896
  }
3816
- if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
3897
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
3817
3898
  return aligned ? mmp->a_m : mmp->m;
3818
3899
  }
3819
3900
  return aligned ? mmp->a_s : mmp->s;
3820
3901
  }
3821
3902
 
3822
- if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
3903
+ if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
3823
3904
  return aligned ? mmp->a_s : mmp->s;
3824
3905
  }
3825
- if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
3906
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
3826
3907
  return aligned ? mmp->a_m : mmp->m;
3827
3908
  }
3828
3909
  return aligned ? mmp->a_l : mmp->l;
3829
3910
  }
3830
3911
 
3831
- static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3832
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3833
- return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
3912
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
3913
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
3914
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
3834
3915
  }
3835
3916
 
3836
3917
  static void ggml_vk_matmul_id(
@@ -4011,10 +4092,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4011
4092
  const int y_ne = ne11 * ne10;
4012
4093
  const int d_ne = ne11 * ne01;
4013
4094
 
4014
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
4095
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4015
4096
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4016
4097
 
4017
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
4098
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4018
4099
 
4019
4100
  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
4020
4101
 
@@ -4593,10 +4674,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4593
4674
  const uint64_t y_ne = ne11 * ne10;
4594
4675
  const uint64_t d_ne = ne21 * ne20;
4595
4676
 
4596
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
4677
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4597
4678
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
4598
4679
 
4599
- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
4680
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4600
4681
 
4601
4682
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4602
4683
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -5127,6 +5208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5127
5208
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
5128
5209
  }
5129
5210
  return nullptr;
5211
+ case GGML_OP_SUB:
5212
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5213
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
5214
+ }
5215
+ return nullptr;
5130
5216
  case GGML_OP_MUL:
5131
5217
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5132
5218
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5188,6 +5274,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5188
5274
  return ctx->device->pipeline_repeat_f32;
5189
5275
  }
5190
5276
  return nullptr;
5277
+ case GGML_OP_REPEAT_BACK:
5278
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5279
+ return ctx->device->pipeline_repeat_back_f32;
5280
+ }
5281
+ return nullptr;
5191
5282
  case GGML_OP_CPY:
5192
5283
  case GGML_OP_CONT:
5193
5284
  case GGML_OP_DUP:
@@ -5257,6 +5348,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5257
5348
  {
5258
5349
  const int mode = ((const int32_t *) dst->op_params)[2];
5259
5350
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5351
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5352
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5260
5353
 
5261
5354
  if (is_neox) {
5262
5355
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5265,6 +5358,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5265
5358
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5266
5359
  return ctx->device->pipeline_rope_neox_f16;
5267
5360
  }
5361
+ } else if (is_mrope && !is_vision) {
5362
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5363
+ return ctx->device->pipeline_rope_multi_f32;
5364
+ }
5365
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5366
+ return ctx->device->pipeline_rope_multi_f16;
5367
+ }
5368
+ } else if (is_vision) {
5369
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5370
+ return ctx->device->pipeline_rope_vision_f32;
5371
+ }
5372
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5373
+ return ctx->device->pipeline_rope_vision_f16;
5374
+ }
5268
5375
  } else {
5269
5376
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5270
5377
  return ctx->device->pipeline_rope_norm_f32;
@@ -5280,11 +5387,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5280
5387
  return ctx->device->pipeline_argsort_f32;
5281
5388
  }
5282
5389
  return nullptr;
5390
+ case GGML_OP_SUM:
5283
5391
  case GGML_OP_SUM_ROWS:
5284
5392
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5285
5393
  return ctx->device->pipeline_sum_rows_f32;
5286
5394
  }
5287
5395
  return nullptr;
5396
+ case GGML_OP_ARGMAX:
5397
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
5398
+ return ctx->device->pipeline_argmax_f32;
5399
+ }
5400
+ return nullptr;
5401
+ case GGML_OP_COUNT_EQUAL:
5402
+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
5403
+ return ctx->device->pipeline_count_equal_i32;
5404
+ }
5405
+ return nullptr;
5288
5406
  case GGML_OP_IM2COL:
5289
5407
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5290
5408
  return ctx->device->pipeline_im2col_f32;
@@ -5308,6 +5426,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5308
5426
  return ctx->device->pipeline_rwkv_wkv6_f32;
5309
5427
  }
5310
5428
  return nullptr;
5429
+ case GGML_OP_OPT_STEP_ADAMW:
5430
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5431
+ return ctx->device->pipeline_opt_step_adamw_f32;
5432
+ }
5433
+ return nullptr;
5311
5434
  case GGML_OP_LEAKY_RELU:
5312
5435
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5313
5436
  return ctx->device->pipeline_leaky_relu_f32;
@@ -5325,6 +5448,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5325
5448
  case GGML_OP_CPY:
5326
5449
  case GGML_OP_GET_ROWS:
5327
5450
  case GGML_OP_ADD:
5451
+ case GGML_OP_SUB:
5328
5452
  case GGML_OP_MUL:
5329
5453
  case GGML_OP_DIV:
5330
5454
  case GGML_OP_CONCAT:
@@ -5335,6 +5459,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5335
5459
  case GGML_OP_CLAMP:
5336
5460
  case GGML_OP_PAD:
5337
5461
  case GGML_OP_REPEAT:
5462
+ case GGML_OP_REPEAT_BACK:
5463
+ case GGML_OP_ROPE:
5338
5464
  return true;
5339
5465
  default:
5340
5466
  return false;
@@ -5548,6 +5674,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5548
5674
  case GGML_OP_RMS_NORM:
5549
5675
  case GGML_OP_SOFT_MAX:
5550
5676
  case GGML_OP_SUM_ROWS:
5677
+ case GGML_OP_ARGMAX:
5551
5678
  {
5552
5679
  const uint32_t nr = ggml_nrows(src0);
5553
5680
  if (nr > 262144) {
@@ -5558,6 +5685,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5558
5685
  elements = { nr, 1, 1 };
5559
5686
  }
5560
5687
  } break;
5688
+ case GGML_OP_SUM:
5689
+ // We use GGML_OP_SUM_ROWS with 1 row.
5690
+ elements = { 1, 1, 1 };
5691
+ break;
5561
5692
  case GGML_OP_GROUP_NORM:
5562
5693
  {
5563
5694
  const uint32_t num_groups = dst->op_params[0];
@@ -5604,6 +5735,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5604
5735
  elements = { N * OC * OH * OW, 1, 1};
5605
5736
  } break;
5606
5737
  case GGML_OP_ADD:
5738
+ case GGML_OP_SUB:
5607
5739
  case GGML_OP_DIV:
5608
5740
  case GGML_OP_MUL:
5609
5741
  case GGML_OP_SCALE:
@@ -5613,6 +5745,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5613
5745
  case GGML_OP_CLAMP:
5614
5746
  case GGML_OP_PAD:
5615
5747
  case GGML_OP_REPEAT:
5748
+ case GGML_OP_REPEAT_BACK:
5616
5749
  case GGML_OP_CPY:
5617
5750
  case GGML_OP_CONCAT:
5618
5751
  case GGML_OP_UPSCALE:
@@ -5673,6 +5806,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5673
5806
  // im2col uses only src1 and dst buffers
5674
5807
  ggml_vk_sync_buffers(subctx);
5675
5808
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5809
+ } else if (op == GGML_OP_COUNT_EQUAL) {
5810
+ ggml_vk_sync_buffers(subctx);
5811
+ // count_equal assumes that destination buffer is initialized with zeroes
5812
+ ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
5813
+ ggml_vk_sync_buffers(subctx);
5814
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5676
5815
  } else if (use_src2) {
5677
5816
  ggml_vk_sync_buffers(subctx);
5678
5817
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -5735,6 +5874,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
5735
5874
  }, dryrun);
5736
5875
  }
5737
5876
 
5877
+ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5878
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
5879
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
5880
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
5881
+
5882
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
5883
+ (uint32_t)ggml_nelements(src0),
5884
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
5885
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
5886
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
5887
+ 0,
5888
+ 0.0f, 0.0f, 0,
5889
+ }, dryrun);
5890
+ }
5891
+
5738
5892
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5739
5893
  const uint32_t src0_type_size = ggml_type_size(src0->type);
5740
5894
  const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -5893,6 +6047,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
5893
6047
  );
5894
6048
  }
5895
6049
 
6050
+ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
6051
+ const ggml_tensor * x = dst->src[0];
6052
+ const ggml_tensor * g = dst->src[1];
6053
+ const ggml_tensor * gm = dst->src[2];
6054
+ const ggml_tensor * gv = dst->src[3];
6055
+ const ggml_tensor * p = dst->src[4];
6056
+
6057
+ GGML_ASSERT(x->type == GGML_TYPE_F32);
6058
+ GGML_ASSERT(g->type == GGML_TYPE_F32);
6059
+ GGML_ASSERT(gm->type == GGML_TYPE_F32);
6060
+ GGML_ASSERT(gv->type == GGML_TYPE_F32);
6061
+ GGML_ASSERT(p->type == GGML_TYPE_F32);
6062
+ GGML_ASSERT(dst->buffer != nullptr);
6063
+ GGML_ASSERT(ggml_is_contiguous(x));
6064
+ GGML_ASSERT(ggml_is_contiguous(g));
6065
+ GGML_ASSERT(ggml_is_contiguous(gm));
6066
+ GGML_ASSERT(ggml_is_contiguous(gv));
6067
+ GGML_ASSERT(ggml_is_contiguous(p));
6068
+ GGML_ASSERT(ggml_are_same_shape(x, g));
6069
+ GGML_ASSERT(ggml_are_same_shape(x, gm));
6070
+ GGML_ASSERT(ggml_are_same_shape(x, gv));
6071
+ GGML_ASSERT(ggml_nelements(p) == 7);
6072
+
6073
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
6074
+ GGML_ASSERT(pipeline != nullptr);
6075
+
6076
+ if (dryrun) {
6077
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6078
+ return;
6079
+ }
6080
+
6081
+ ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
6082
+ ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
6083
+ ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
6084
+ ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
6085
+ ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
6086
+
6087
+ ggml_vk_sync_buffers(subctx);
6088
+
6089
+ vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
6090
+ size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
6091
+ bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
6092
+
6093
+ if (ctx->device->uma) {
6094
+ ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
6095
+ ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
6096
+ ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
6097
+ ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
6098
+ ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
6099
+
6100
+ X_uma = d_X != nullptr;
6101
+ G_uma = d_G != nullptr;
6102
+ GM_uma = d_GM != nullptr;
6103
+ GV_uma = d_GV != nullptr;
6104
+ P_uma = d_P != nullptr;
6105
+ }
6106
+
6107
+ if (!X_uma) {
6108
+ d_X = x_buf_ctx->dev_buffer;
6109
+ x_offset = vk_tensor_offset(x) + x->view_offs;
6110
+ }
6111
+ if (!G_uma) {
6112
+ d_G = g_buf_ctx->dev_buffer;
6113
+ g_offset = vk_tensor_offset(g) + g->view_offs;
6114
+ }
6115
+ if (!GM_uma) {
6116
+ d_GM = gm_buf_ctx->dev_buffer;
6117
+ gm_offset = vk_tensor_offset(gm) + gm->view_offs;
6118
+ }
6119
+ if (!GV_uma) {
6120
+ d_GV = gv_buf_ctx->dev_buffer;
6121
+ gv_offset = vk_tensor_offset(gv) + gv->view_offs;
6122
+ }
6123
+ if (!P_uma) {
6124
+ d_P = p_buf_ctx->dev_buffer;
6125
+ p_offset = vk_tensor_offset(p) + p->view_offs;
6126
+ }
6127
+
6128
+ const uint64_t x_size = ggml_nbytes(x);
6129
+ const uint64_t g_size = ggml_nbytes(g);
6130
+ const uint64_t gm_size = ggml_nbytes(gm);
6131
+ const uint64_t gv_size = ggml_nbytes(gv);
6132
+ const uint64_t p_size = ggml_nbytes(p);
6133
+
6134
+ std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
6135
+
6136
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
6137
+ vk_subbuffer{ d_X, x_offset, x_size },
6138
+ vk_subbuffer{ d_G, g_offset, g_size },
6139
+ vk_subbuffer{ d_GM, gm_offset, gm_size },
6140
+ vk_subbuffer{ d_GV, gv_offset, gv_size },
6141
+ vk_subbuffer{ d_P, p_offset, p_size },
6142
+ }, sizeof(vk_op_push_constants), &pc, elements);
6143
+ }
6144
+
6145
+ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
6146
+ const size_t n = ggml_nelements(dst->src[0]);
6147
+
6148
+ ggml_vk_op_f32_opt_step_adamw(
6149
+ ctx, subctx, dst,
6150
+ { (uint32_t)n, 0, 0.0f, 0.0f },
6151
+ dryrun
6152
+ );
6153
+ }
6154
+
5896
6155
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5897
6156
  int * op_params = (int *)dst->op_params;
5898
6157
 
@@ -6026,6 +6285,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
6026
6285
  }, dryrun);
6027
6286
  }
6028
6287
 
6288
+ static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6289
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
6290
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
6291
+
6292
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
6293
+ (uint32_t)ggml_nelements(dst),
6294
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6295
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6296
+ 0,
6297
+ 0.0f, 0.0f,
6298
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6299
+ }, dryrun);
6300
+ }
6301
+
6029
6302
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6030
6303
  const uint32_t src0_type_size = ggml_type_size(src0->type);
6031
6304
  const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -6099,7 +6372,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
6099
6372
 
6100
6373
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
6101
6374
  const int n_dims = ((int32_t *) dst->op_params)[1];
6102
- // const int mode = ((int32_t *) dst->op_params)[2];
6375
+ const int mode = ((int32_t *) dst->op_params)[2];
6103
6376
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
6104
6377
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
6105
6378
  const float freq_base = ((float *) dst->op_params)[5];
@@ -6108,16 +6381,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
6108
6381
  const float attn_factor = ((float *) dst->op_params)[8];
6109
6382
  const float beta_fast = ((float *) dst->op_params)[9];
6110
6383
  const float beta_slow = ((float *) dst->op_params)[10];
6384
+ int sections[4] {};
6385
+ if (mode & GGML_ROPE_TYPE_MROPE) {
6386
+ memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
6387
+ }
6111
6388
 
6112
6389
  float corr_dims[2];
6113
6390
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
6114
6391
 
6115
6392
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
6116
6393
 
6394
+ uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
6395
+ uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
6396
+
6117
6397
  ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
6118
6398
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
6119
6399
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
6120
- src2 != nullptr,
6400
+ src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
6401
+ sections[0], sections[1], sections[2], sections[3],
6121
6402
  }, dryrun);
6122
6403
  }
6123
6404
 
@@ -6140,10 +6421,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
6140
6421
  }, dryrun);
6141
6422
  }
6142
6423
 
6424
+ static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6425
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6426
+ }
6427
+
6143
6428
  static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6144
6429
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
6145
6430
  }
6146
6431
 
6432
+ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6433
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
6434
+ }
6435
+
6436
+ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6437
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6438
+ }
6439
+
6147
6440
  static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6148
6441
  const int32_t s0 = dst->op_params[0];
6149
6442
  const int32_t s1 = dst->op_params[1];
@@ -7008,9 +7301,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7008
7301
  }
7009
7302
  break;
7010
7303
  case GGML_OP_REPEAT:
7304
+ case GGML_OP_REPEAT_BACK:
7011
7305
  case GGML_OP_GET_ROWS:
7012
7306
  case GGML_OP_ADD:
7013
7307
  case GGML_OP_ACC:
7308
+ case GGML_OP_SUB:
7014
7309
  case GGML_OP_MUL:
7015
7310
  case GGML_OP_DIV:
7016
7311
  case GGML_OP_CONCAT:
@@ -7033,13 +7328,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7033
7328
  case GGML_OP_MUL_MAT:
7034
7329
  case GGML_OP_MUL_MAT_ID:
7035
7330
  case GGML_OP_ARGSORT:
7331
+ case GGML_OP_SUM:
7036
7332
  case GGML_OP_SUM_ROWS:
7333
+ case GGML_OP_ARGMAX:
7334
+ case GGML_OP_COUNT_EQUAL:
7037
7335
  case GGML_OP_IM2COL:
7038
7336
  case GGML_OP_TIMESTEP_EMBEDDING:
7039
7337
  case GGML_OP_POOL_2D:
7040
7338
  case GGML_OP_RWKV_WKV6:
7041
7339
  case GGML_OP_LEAKY_RELU:
7042
7340
  case GGML_OP_FLASH_ATTN_EXT:
7341
+ case GGML_OP_OPT_STEP_ADAMW:
7043
7342
  break;
7044
7343
  default:
7045
7344
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7060,9 +7359,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7060
7359
  } else {
7061
7360
  switch (node->op) {
7062
7361
  case GGML_OP_REPEAT:
7362
+ case GGML_OP_REPEAT_BACK:
7063
7363
  case GGML_OP_ACC:
7064
7364
  case GGML_OP_GET_ROWS:
7065
7365
  case GGML_OP_ADD:
7366
+ case GGML_OP_SUB:
7066
7367
  case GGML_OP_MUL:
7067
7368
  case GGML_OP_DIV:
7068
7369
  case GGML_OP_CONCAT:
@@ -7084,7 +7385,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7084
7385
  case GGML_OP_SOFT_MAX:
7085
7386
  case GGML_OP_ROPE:
7086
7387
  case GGML_OP_ARGSORT:
7388
+ case GGML_OP_SUM:
7087
7389
  case GGML_OP_SUM_ROWS:
7390
+ case GGML_OP_ARGMAX:
7391
+ case GGML_OP_COUNT_EQUAL:
7088
7392
  case GGML_OP_IM2COL:
7089
7393
  case GGML_OP_TIMESTEP_EMBEDDING:
7090
7394
  case GGML_OP_POOL_2D:
@@ -7105,6 +7409,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7105
7409
  case GGML_OP_REPEAT:
7106
7410
  ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
7107
7411
 
7412
+ break;
7413
+ case GGML_OP_REPEAT_BACK:
7414
+ ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
7415
+
7108
7416
  break;
7109
7417
  case GGML_OP_ACC:
7110
7418
  ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7117,6 +7425,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7117
7425
  case GGML_OP_ADD:
7118
7426
  ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
7119
7427
 
7428
+ break;
7429
+ case GGML_OP_SUB:
7430
+ ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
7431
+
7120
7432
  break;
7121
7433
  case GGML_OP_MUL:
7122
7434
  ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7204,10 +7516,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7204
7516
  case GGML_OP_ARGSORT:
7205
7517
  ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
7206
7518
 
7519
+ break;
7520
+ case GGML_OP_SUM:
7521
+ ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
7522
+
7207
7523
  break;
7208
7524
  case GGML_OP_SUM_ROWS:
7209
7525
  ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
7210
7526
 
7527
+ break;
7528
+ case GGML_OP_ARGMAX:
7529
+ ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
7530
+
7531
+ break;
7532
+ case GGML_OP_COUNT_EQUAL:
7533
+ ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
7534
+
7211
7535
  break;
7212
7536
  case GGML_OP_IM2COL:
7213
7537
  ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7242,6 +7566,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7242
7566
  case GGML_OP_RWKV_WKV6:
7243
7567
  ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
7244
7568
 
7569
+ break;
7570
+
7571
+ case GGML_OP_OPT_STEP_ADAMW:
7572
+ ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
7573
+
7245
7574
  break;
7246
7575
  default:
7247
7576
  return false;
@@ -7293,6 +7622,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7293
7622
  case GGML_OP_ADD:
7294
7623
  case GGML_OP_ACC:
7295
7624
  case GGML_OP_GET_ROWS:
7625
+ case GGML_OP_SUB:
7296
7626
  case GGML_OP_MUL:
7297
7627
  case GGML_OP_DIV:
7298
7628
  case GGML_OP_CONCAT:
@@ -7318,13 +7648,18 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7318
7648
  case GGML_OP_TRANSPOSE:
7319
7649
  case GGML_OP_NONE:
7320
7650
  case GGML_OP_ARGSORT:
7651
+ case GGML_OP_SUM:
7321
7652
  case GGML_OP_SUM_ROWS:
7653
+ case GGML_OP_ARGMAX:
7654
+ case GGML_OP_COUNT_EQUAL:
7322
7655
  case GGML_OP_IM2COL:
7323
7656
  case GGML_OP_TIMESTEP_EMBEDDING:
7324
7657
  case GGML_OP_POOL_2D:
7325
7658
  case GGML_OP_RWKV_WKV6:
7326
7659
  case GGML_OP_LEAKY_RELU:
7327
7660
  case GGML_OP_REPEAT:
7661
+ case GGML_OP_REPEAT_BACK:
7662
+ case GGML_OP_OPT_STEP_ADAMW:
7328
7663
  buf = tensor->buffer;
7329
7664
 
7330
7665
  break;
@@ -7516,6 +7851,15 @@ static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggm
7516
7851
  }
7517
7852
  }
7518
7853
 
7854
+ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
7855
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
7856
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
7857
+ vk_buffer buf = buf_ctx->dev_buffer;
7858
+
7859
+ uint32_t val32 = (uint32_t)value * 0x01010101;
7860
+ ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
7861
+ }
7862
+
7519
7863
  static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
7520
7864
  VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
7521
7865
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@@ -7560,7 +7904,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
7560
7904
  /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
7561
7905
  /* .get_base = */ ggml_backend_vk_buffer_get_base,
7562
7906
  /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
7563
- /* .memset_tensor = */ NULL,
7907
+ /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
7564
7908
  /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
7565
7909
  /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
7566
7910
  /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
@@ -8035,13 +8379,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8035
8379
  case GGML_OP_MUL_MAT:
8036
8380
  case GGML_OP_MUL_MAT_ID:
8037
8381
  {
8382
+ ggml_type src0_type = op->src[0]->type;
8038
8383
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
8039
8384
  const vk_device& device = ggml_vk_get_device(ctx->device);
8040
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
8385
+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
8041
8386
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
8042
8387
  return false;
8043
8388
  }
8044
- switch (op->src[0]->type) {
8389
+ switch (src0_type) {
8045
8390
  case GGML_TYPE_F32:
8046
8391
  case GGML_TYPE_F16:
8047
8392
  case GGML_TYPE_Q4_0:
@@ -8054,6 +8399,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8054
8399
  case GGML_TYPE_Q4_K:
8055
8400
  case GGML_TYPE_Q5_K:
8056
8401
  case GGML_TYPE_Q6_K:
8402
+ case GGML_TYPE_IQ1_S:
8403
+ case GGML_TYPE_IQ1_M:
8057
8404
  case GGML_TYPE_IQ2_XXS:
8058
8405
  case GGML_TYPE_IQ2_XS:
8059
8406
  case GGML_TYPE_IQ2_S:
@@ -8128,6 +8475,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8128
8475
  //case GGML_TYPE_Q4_K:
8129
8476
  //case GGML_TYPE_Q5_K:
8130
8477
  //case GGML_TYPE_Q6_K:
8478
+ //case GGML_TYPE_IQ1_S:
8479
+ //case GGML_TYPE_IQ1_M:
8131
8480
  //case GGML_TYPE_IQ2_XXS:
8132
8481
  //case GGML_TYPE_IQ2_XS:
8133
8482
  //case GGML_TYPE_IQ2_S:
@@ -8151,6 +8500,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8151
8500
  case GGML_TYPE_Q5_0:
8152
8501
  case GGML_TYPE_Q5_1:
8153
8502
  case GGML_TYPE_Q8_0:
8503
+ case GGML_TYPE_IQ1_S:
8504
+ case GGML_TYPE_IQ1_M:
8154
8505
  case GGML_TYPE_IQ2_XXS:
8155
8506
  case GGML_TYPE_IQ2_XS:
8156
8507
  case GGML_TYPE_IQ2_S:
@@ -8206,17 +8557,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8206
8557
  } break;
8207
8558
  case GGML_OP_REPEAT:
8208
8559
  return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
8560
+ case GGML_OP_REPEAT_BACK:
8561
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
8209
8562
  case GGML_OP_ROPE:
8210
- {
8211
- const int mode = ((const int32_t *) op->op_params)[2];
8212
- if (mode & GGML_ROPE_TYPE_MROPE) {
8213
- return false;
8214
- }
8215
- if (mode & GGML_ROPE_TYPE_VISION) {
8216
- return false;
8217
- }
8218
- return ggml_is_contiguous(op->src[0]);
8219
- }
8220
8563
  case GGML_OP_NONE:
8221
8564
  case GGML_OP_RESHAPE:
8222
8565
  case GGML_OP_VIEW:
@@ -8229,6 +8572,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8229
8572
  return ggml_is_contiguous(op->src[0]);
8230
8573
  case GGML_OP_ADD:
8231
8574
  case GGML_OP_ACC:
8575
+ case GGML_OP_SUB:
8232
8576
  case GGML_OP_MUL:
8233
8577
  case GGML_OP_DIV:
8234
8578
  case GGML_OP_CONCAT:
@@ -8242,12 +8586,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8242
8586
  case GGML_OP_DIAG_MASK_INF:
8243
8587
  case GGML_OP_SOFT_MAX:
8244
8588
  case GGML_OP_ARGSORT:
8589
+ case GGML_OP_SUM:
8245
8590
  case GGML_OP_SUM_ROWS:
8591
+ case GGML_OP_ARGMAX:
8592
+ case GGML_OP_COUNT_EQUAL:
8246
8593
  case GGML_OP_IM2COL:
8247
8594
  case GGML_OP_TIMESTEP_EMBEDDING:
8248
8595
  case GGML_OP_POOL_2D:
8249
8596
  case GGML_OP_RWKV_WKV6:
8250
8597
  case GGML_OP_LEAKY_RELU:
8598
+ case GGML_OP_OPT_STEP_ADAMW:
8251
8599
  return true;
8252
8600
  default:
8253
8601
  return false;
@@ -8347,8 +8695,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
8347
8695
  /* .iface = */ ggml_backend_vk_reg_i,
8348
8696
  /* .context = */ nullptr,
8349
8697
  };
8350
-
8351
- return &reg;
8698
+ try {
8699
+ ggml_vk_instance_init();
8700
+ return &reg;
8701
+ } catch (const vk::SystemError& e) {
8702
+ VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
8703
+ return nullptr;
8704
+ }
8352
8705
  }
8353
8706
 
8354
8707
  // Extension availability
@@ -8515,8 +8868,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8515
8868
 
8516
8869
  ggml_tensor * src0 = tensor->src[0];
8517
8870
  ggml_tensor * src1 = tensor->src[1];
8518
- ggml_tensor * src2 = tensor->src[2];
8519
- ggml_tensor * src3 = tensor->src[3];
8520
8871
 
8521
8872
  struct ggml_init_params iparams = {
8522
8873
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -8526,238 +8877,113 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8526
8877
 
8527
8878
  struct ggml_context * ggml_ctx = ggml_init(iparams);
8528
8879
 
8529
- struct ggml_tensor * src0_clone = nullptr;
8530
- struct ggml_tensor * src1_clone = nullptr;
8531
- struct ggml_tensor * src2_clone = nullptr;
8532
- struct ggml_tensor * src3_clone = nullptr;
8533
- struct ggml_tensor * tensor_clone = nullptr;
8534
-
8535
- size_t src0_size;
8536
- size_t src1_size;
8537
- size_t src2_size;
8538
- size_t src3_size;
8539
-
8540
- void * src0_buffer = nullptr;
8541
- void * src1_buffer = nullptr;
8542
- void * src2_buffer = nullptr;
8543
- void * src3_buffer = nullptr;
8544
-
8545
- if (src0 != nullptr) {
8546
- src0_clone = ggml_dup_tensor(ggml_ctx, src0);
8547
-
8548
- src0_size = ggml_nbytes(src0);
8880
+ std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
8881
+ std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
8882
+ std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
8883
+ const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
8549
8884
 
8550
- src0_buffer = malloc(src0_size);
8551
- src0_clone->data = src0_buffer;
8552
- if (ggml_backend_buffer_is_host(src0->buffer)) {
8553
- memcpy(src0_clone->data, src0->data, src0_size);
8554
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
8555
- } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
8556
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
8557
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8558
- uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
8559
- if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
8560
- for (int i3 = 0; i3 < src0->ne[3]; i3++) {
8561
- for (int i2 = 0; i2 < src0->ne[2]; i2++) {
8562
- const int idx = i3*src0->ne[2] + i2;
8563
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
8564
- }
8565
- }
8566
-
8567
- src0_clone->nb[0] = src0->nb[0];
8568
- src0_clone->nb[1] = src0->nb[1];
8569
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8570
- src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
8571
- }
8572
- } else {
8573
- if (offset + src0_size >= buffer_gpu->size) {
8574
- src0_size = buffer_gpu->size - offset;
8575
- }
8576
- ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
8577
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
8578
- }
8579
- } else {
8580
- GGML_ABORT("fatal error");
8581
- }
8582
-
8583
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8584
- ggml_vk_print_tensor(src0, "src0");
8585
- }
8586
- }
8587
- if (src1 != nullptr) {
8588
- src1_clone = ggml_dup_tensor(ggml_ctx, src1);
8589
-
8590
- src1_size = ggml_nbytes(src1);
8591
-
8592
- src1_buffer = malloc(src1_size);
8593
- src1_clone->data = src1_buffer;
8594
- if (ggml_backend_buffer_is_host(src1->buffer)) {
8595
- memcpy(src1_clone->data, src1->data, src1_size);
8596
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
8597
- } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
8598
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
8599
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8600
- uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
8601
- if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
8602
- for (int i3 = 0; i3 < src1->ne[3]; i3++) {
8603
- for (int i2 = 0; i2 < src1->ne[2]; i2++) {
8604
- const int idx = i3*src1->ne[2] + i2;
8605
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
8606
- }
8607
- }
8608
-
8609
- src1_clone->nb[0] = src1->nb[0];
8610
- src1_clone->nb[1] = src1->nb[1];
8611
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8612
- src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
8613
- }
8614
- } else {
8615
- if (offset + src1_size >= buffer_gpu->size) {
8616
- src1_size = buffer_gpu->size - offset;
8617
- }
8618
- ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
8619
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
8620
- }
8621
- } else {
8622
- GGML_ABORT("fatal error");
8623
- }
8624
-
8625
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8626
- ggml_vk_print_tensor(src1, "src1");
8627
- }
8628
- }
8629
- if (src2 != nullptr) {
8630
- src2_clone = ggml_dup_tensor(ggml_ctx, src2);
8631
-
8632
- src2_size = ggml_nbytes(src2);
8633
-
8634
- src2_buffer = malloc(src2_size);
8635
- src2_clone->data = src2_buffer;
8636
- if (ggml_backend_buffer_is_host(src2->buffer)) {
8637
- memcpy(src2_clone->data, src2->data, src2_size);
8638
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
8639
- } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
8640
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
8641
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8642
- uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
8643
- if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
8644
- for (int i3 = 0; i3 < src2->ne[3]; i3++) {
8645
- for (int i2 = 0; i2 < src2->ne[2]; i2++) {
8646
- const int idx = i3*src2->ne[2] + i2;
8647
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
8648
- }
8649
- }
8650
-
8651
- src2_clone->nb[0] = src2->nb[0];
8652
- src2_clone->nb[1] = src2->nb[1];
8653
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8654
- src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
8655
- }
8656
- } else {
8657
- if (offset + src2_size >= buffer_gpu->size) {
8658
- src2_size = buffer_gpu->size - offset;
8659
- }
8660
- ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
8661
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
8662
- }
8663
- } else {
8664
- GGML_ABORT("fatal error");
8665
- }
8885
+ struct ggml_tensor * tensor_clone = nullptr;
8666
8886
 
8667
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8668
- ggml_vk_print_tensor(src2, "src2");
8887
+ for (int i = 0; i < 6; i++) {
8888
+ ggml_tensor * srci = tensor->src[i];
8889
+ if (srci == nullptr) {
8890
+ continue;
8669
8891
  }
8670
- }
8671
- if (src3 != nullptr) {
8672
- src3_clone = ggml_dup_tensor(ggml_ctx, src3);
8892
+ ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
8893
+ size_t srci_size = ggml_nbytes(srci);
8673
8894
 
8674
- src3_size = ggml_nbytes(src3);
8895
+ src_clone[i] = srci_clone;
8896
+ src_size[i] = ggml_nbytes(srci);
8897
+ src_buffer[i] = malloc(srci_size);
8675
8898
 
8676
- src3_buffer = malloc(src3_size);
8677
- src3_clone->data = src3_buffer;
8678
- if (ggml_backend_buffer_is_host(src3->buffer)) {
8679
- memcpy(src3_clone->data, src3->data, src3_size);
8680
- memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8681
- } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
8682
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
8899
+ srci_clone->data = src_buffer[i];
8900
+ if (ggml_backend_buffer_is_host(srci->buffer)) {
8901
+ memcpy(srci_clone->data, srci->data, srci_size);
8902
+ memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
8903
+ } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
8904
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
8683
8905
  vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8684
- uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
8685
- if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
8686
- for (int i3 = 0; i3 < src3->ne[3]; i3++) {
8687
- for (int i2 = 0; i2 < src3->ne[2]; i2++) {
8688
- const int idx = i3*src3->ne[2] + i2;
8689
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
8906
+ uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
8907
+ if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
8908
+ for (int i3 = 0; i3 < srci->ne[3]; i3++) {
8909
+ for (int i2 = 0; i2 < srci->ne[2]; i2++) {
8910
+ const int idx = i3*srci->ne[2] + i2;
8911
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
8690
8912
  }
8691
8913
  }
8692
8914
 
8693
- src3_clone->nb[0] = src3->nb[0];
8694
- src3_clone->nb[1] = src3->nb[1];
8915
+ srci_clone->nb[0] = srci->nb[0];
8916
+ srci_clone->nb[1] = srci->nb[1];
8695
8917
  for (int i = 2; i < GGML_MAX_DIMS; i++) {
8696
- src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
8918
+ srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
8697
8919
  }
8698
8920
  } else {
8699
- if (offset + src3_size >= buffer_gpu->size) {
8700
- src3_size = buffer_gpu->size - offset;
8921
+ if (offset + srci_size >= buffer_gpu->size) {
8922
+ srci_size = buffer_gpu->size - offset;
8701
8923
  }
8702
- ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
8703
- memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8924
+ ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
8925
+ memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
8704
8926
  }
8705
8927
  } else {
8706
8928
  GGML_ABORT("fatal error");
8707
8929
  }
8708
8930
 
8709
8931
  if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8710
- ggml_vk_print_tensor(src3, "src3");
8932
+ ggml_vk_print_tensor(srci, srci_name[i]);
8711
8933
  }
8712
8934
  }
8713
8935
 
8714
8936
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
8715
8937
  const float *params = (const float *)tensor->op_params;
8716
- tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
8938
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
8717
8939
  } else if (tensor->op == GGML_OP_MUL_MAT) {
8718
- tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
8940
+ tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
8719
8941
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
8720
- tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
8942
+ tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
8943
+ } else if (tensor->op == GGML_OP_SUB) {
8944
+ tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
8721
8945
  } else if (tensor->op == GGML_OP_MUL) {
8722
- tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
8946
+ tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
8723
8947
  } else if (tensor->op == GGML_OP_DIV) {
8724
- tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
8948
+ tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
8725
8949
  } else if (tensor->op == GGML_OP_CONCAT) {
8726
- tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
8950
+ tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
8727
8951
  } else if (tensor->op == GGML_OP_UPSCALE) {
8728
- tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8952
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8729
8953
  } else if (tensor->op == GGML_OP_SCALE) {
8730
- tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
8954
+ tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
8731
8955
  } else if (tensor->op == GGML_OP_SQR) {
8732
- tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
8956
+ tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
8733
8957
  } else if (tensor->op == GGML_OP_SIN) {
8734
- tensor_clone = ggml_sin(ggml_ctx, src0_clone);
8958
+ tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
8735
8959
  } else if (tensor->op == GGML_OP_COS) {
8736
- tensor_clone = ggml_cos(ggml_ctx, src0_clone);
8960
+ tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
8737
8961
  } else if (tensor->op == GGML_OP_CLAMP) {
8738
- tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8962
+ tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8739
8963
  } else if (tensor->op == GGML_OP_PAD) {
8740
- tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
8964
+ tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
8741
8965
  } else if (tensor->op == GGML_OP_REPEAT) {
8742
- tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
8966
+ tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
8967
+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
8968
+ tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
8743
8969
  } else if (tensor->op == GGML_OP_ADD) {
8744
- tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
8970
+ tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
8745
8971
  } else if (tensor->op == GGML_OP_ACC) {
8746
- tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
8972
+ tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
8747
8973
  } else if (tensor->op == GGML_OP_NORM) {
8748
- tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
8974
+ tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
8749
8975
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
8750
- tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
8976
+ tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
8751
8977
  } else if (tensor->op == GGML_OP_RMS_NORM) {
8752
- tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
8978
+ tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
8753
8979
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
8754
8980
  if (src1 != nullptr) {
8755
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8981
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8756
8982
  } else {
8757
- tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
8983
+ tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
8758
8984
  }
8759
8985
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
8760
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
8986
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
8761
8987
  } else if (tensor->op == GGML_OP_ROPE) {
8762
8988
  const int n_dims = ((int32_t *) tensor->op_params)[1];
8763
8989
  const int mode = ((int32_t *) tensor->op_params)[2];
@@ -8769,23 +8995,28 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8769
8995
  const float attn_factor = ((float *) tensor->op_params)[8];
8770
8996
  const float beta_fast = ((float *) tensor->op_params)[9];
8771
8997
  const float beta_slow = ((float *) tensor->op_params)[10];
8772
- tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
8998
+ if (mode & GGML_ROPE_TYPE_MROPE) {
8999
+ int32_t *sections = ((int32_t *) tensor->op_params) + 11;
9000
+ tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9001
+ } else {
9002
+ tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9003
+ }
8773
9004
  } else if (tensor->op == GGML_OP_UNARY) {
8774
9005
  switch (ggml_get_unary_op(tensor)) {
8775
9006
  case GGML_UNARY_OP_SILU:
8776
- tensor_clone = ggml_silu(ggml_ctx, src0_clone);
9007
+ tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
8777
9008
  break;
8778
9009
  case GGML_UNARY_OP_GELU:
8779
- tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
9010
+ tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
8780
9011
  break;
8781
9012
  case GGML_UNARY_OP_GELU_QUICK:
8782
- tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
9013
+ tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
8783
9014
  break;
8784
9015
  case GGML_UNARY_OP_RELU:
8785
- tensor_clone = ggml_relu(ggml_ctx, src0_clone);
9016
+ tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
8786
9017
  break;
8787
9018
  case GGML_UNARY_OP_TANH:
8788
- tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
9019
+ tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
8789
9020
  break;
8790
9021
  default:
8791
9022
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8793,28 +9024,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8793
9024
  }
8794
9025
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
8795
9026
  if (src1 == nullptr) {
8796
- tensor_clone = ggml_dup(ggml_ctx, src0_clone);
9027
+ tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
8797
9028
  tensor_clone->type = tensor->type;
8798
9029
  } else {
8799
- tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
9030
+ tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
8800
9031
  }
8801
9032
  } else if (tensor->op == GGML_OP_CONT) {
8802
- tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9033
+ tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8803
9034
  } else if (tensor->op == GGML_OP_RESHAPE) {
8804
- tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9035
+ tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8805
9036
  } else if (tensor->op == GGML_OP_VIEW) {
8806
- tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
9037
+ tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
8807
9038
  } else if (tensor->op == GGML_OP_PERMUTE) {
8808
9039
  int32_t * params = (int32_t *)tensor->op_params;
8809
- tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
9040
+ tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
8810
9041
  } else if (tensor->op == GGML_OP_TRANSPOSE) {
8811
- tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
9042
+ tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
8812
9043
  } else if (tensor->op == GGML_OP_GET_ROWS) {
8813
- tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
9044
+ tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
8814
9045
  } else if (tensor->op == GGML_OP_ARGSORT) {
8815
- tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
9046
+ tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
9047
+ } else if (tensor->op == GGML_OP_SUM) {
9048
+ tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
8816
9049
  } else if (tensor->op == GGML_OP_SUM_ROWS) {
8817
- tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
9050
+ tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
9051
+ } else if (tensor->op == GGML_OP_ARGMAX) {
9052
+ tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
9053
+ } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
9054
+ tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
8818
9055
  } else if (tensor->op == GGML_OP_IM2COL) {
8819
9056
  const int32_t s0 = tensor->op_params[0];
8820
9057
  const int32_t s1 = tensor->op_params[1];
@@ -8824,11 +9061,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8824
9061
  const int32_t d1 = tensor->op_params[5];
8825
9062
 
8826
9063
  const bool is_2D = tensor->op_params[6] == 1;
8827
- tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
9064
+ tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
8828
9065
  } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
8829
9066
  const int32_t dim = tensor->op_params[0];
8830
9067
  const int32_t max_period = tensor->op_params[1];
8831
- tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
9068
+ tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
8832
9069
  } else if (tensor->op == GGML_OP_POOL_2D) {
8833
9070
  enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
8834
9071
  const int32_t k0 = tensor->op_params[1];
@@ -8838,13 +9075,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8838
9075
  const int32_t p0 = tensor->op_params[5];
8839
9076
  const int32_t p1 = tensor->op_params[6];
8840
9077
 
8841
- tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
9078
+ tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
8842
9079
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
8843
9080
  const float * op_params = (const float *)tensor->op_params;
8844
- tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
9081
+ tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
8845
9082
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
8846
- tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8847
- tensor->src[4], tensor->src[5]);
9083
+ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
9084
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
9085
+ } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9086
+ src_clone[0]->flags = src0->flags;
9087
+ tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
9088
+ src_clone[2], src_clone[3], src_clone[4]);
8848
9089
  }
8849
9090
  else {
8850
9091
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8866,11 +9107,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8866
9107
  memcpy(comp_result, tensor_clone->data, comp_size);
8867
9108
  memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
8868
9109
 
8869
- if (src0 != nullptr) {
8870
- free(src0_buffer);
8871
- }
8872
- if (src1 != nullptr) {
8873
- free(src1_buffer);
9110
+ for (int i = 0; i < 6; i++) {
9111
+ if (src_buffer[i] != nullptr) {
9112
+ free(src_buffer[i]);
9113
+ }
8874
9114
  }
8875
9115
 
8876
9116
  ggml_free(ggml_ctx);
@@ -8934,6 +9174,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8934
9174
  } else if (tensor->type == GGML_TYPE_I32) {
8935
9175
  correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
8936
9176
  result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
9177
+ } else if (tensor->type == GGML_TYPE_I64) {
9178
+ correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
9179
+ result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
8937
9180
  } else {
8938
9181
  std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
8939
9182
  }