@fugood/llama.node 0.3.12 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +2 -1
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +110 -79
  21. package/src/LlamaContext.h +1 -1
  22. package/src/common.hpp +1 -2
  23. package/src/llama.cpp/.github/workflows/build.yml +95 -13
  24. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  25. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  27. package/src/llama.cpp/common/CMakeLists.txt +23 -6
  28. package/src/llama.cpp/common/arg.cpp +292 -14
  29. package/src/llama.cpp/common/chat.cpp +1128 -315
  30. package/src/llama.cpp/common/chat.h +135 -0
  31. package/src/llama.cpp/common/common.cpp +27 -171
  32. package/src/llama.cpp/common/common.h +41 -73
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  34. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  35. package/src/llama.cpp/common/llguidance.cpp +3 -3
  36. package/src/llama.cpp/common/log.cpp +1 -0
  37. package/src/llama.cpp/common/log.h +2 -1
  38. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
  39. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
  40. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  41. package/src/llama.cpp/common/sampling.cpp +93 -49
  42. package/src/llama.cpp/common/speculative.cpp +6 -5
  43. package/src/llama.cpp/common/speculative.h +1 -1
  44. package/src/llama.cpp/docs/build.md +47 -9
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  47. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  48. package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
  49. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  50. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  52. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  54. package/src/llama.cpp/examples/llava/clip.h +19 -3
  55. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  56. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  57. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  58. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  59. package/src/llama.cpp/examples/main/main.cpp +73 -28
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +115 -79
  67. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/server/httplib.h +381 -292
  69. package/src/llama.cpp/examples/server/server.cpp +134 -128
  70. package/src/llama.cpp/examples/server/utils.hpp +95 -106
  71. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  72. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  73. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  74. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  75. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  76. package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
  77. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  79. package/src/llama.cpp/ggml/include/ggml.h +6 -2
  80. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  81. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  82. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  83. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  84. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  85. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  86. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  87. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  88. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  89. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  90. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
  96. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
  102. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  103. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  104. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  105. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  106. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  107. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
  109. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  110. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  111. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  112. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  115. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  116. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  117. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  121. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
  124. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  125. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  128. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
  129. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
  130. package/src/llama.cpp/ggml/src/ggml.c +9 -4
  131. package/src/llama.cpp/include/llama.h +32 -14
  132. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  133. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  134. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  135. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  136. package/src/llama.cpp/requirements.txt +1 -0
  137. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  138. package/src/llama.cpp/src/llama-arch.h +1 -0
  139. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  140. package/src/llama.cpp/src/llama-grammar.cpp +183 -183
  141. package/src/llama.cpp/src/llama-grammar.h +13 -4
  142. package/src/llama.cpp/src/llama-impl.h +6 -6
  143. package/src/llama.cpp/src/llama-kv-cache.h +2 -1
  144. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  145. package/src/llama.cpp/src/llama-mmap.h +1 -0
  146. package/src/llama.cpp/src/llama-model.cpp +70 -6
  147. package/src/llama.cpp/src/llama-sampling.cpp +174 -67
  148. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  149. package/src/llama.cpp/src/llama.cpp +154 -5
  150. package/src/llama.cpp/src/unicode.cpp +9 -2
  151. package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
  152. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  153. package/src/llama.cpp/tests/test-chat.cpp +691 -325
  154. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  155. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  156. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  157. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
  158. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  159. package/src/llama.cpp/common/chat.hpp +0 -52
@@ -167,6 +167,7 @@ struct vk_device_struct {
167
167
  uint32_t subgroup_size;
168
168
  uint32_t shader_core_count;
169
169
  bool uma;
170
+ bool prefer_host_memory;
170
171
  bool float_controls_rte_fp16;
171
172
 
172
173
  bool subgroup_size_control;
@@ -184,12 +185,12 @@ struct vk_device_struct {
184
185
 
185
186
  size_t idx;
186
187
 
187
- bool mul_mat_l;
188
- bool mul_mat_m;
189
- bool mul_mat_s;
190
- bool mul_mat_id_l;
191
- bool mul_mat_id_m;
192
- bool mul_mat_id_s;
188
+ bool mul_mat_l[GGML_TYPE_COUNT];
189
+ bool mul_mat_m[GGML_TYPE_COUNT];
190
+ bool mul_mat_s[GGML_TYPE_COUNT];
191
+ bool mul_mat_id_l[GGML_TYPE_COUNT];
192
+ bool mul_mat_id_m[GGML_TYPE_COUNT];
193
+ bool mul_mat_id_s[GGML_TYPE_COUNT];
193
194
 
194
195
  // set to true to indicate that some shaders need to be compiled after the dryrun
195
196
  bool need_compiles {};
@@ -221,6 +222,7 @@ struct vk_device_struct {
221
222
  vk_pipeline pipeline_acc_f32;
222
223
  vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
223
224
  vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
225
+ vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
224
226
  vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
225
227
  vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
226
228
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -231,7 +233,7 @@ struct vk_device_struct {
231
233
  vk_pipeline pipeline_cos_f32;
232
234
  vk_pipeline pipeline_clamp_f32;
233
235
  vk_pipeline pipeline_pad_f32;
234
- vk_pipeline pipeline_repeat_f32;
236
+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
235
237
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
236
238
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
237
239
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -239,23 +241,32 @@ struct vk_device_struct {
239
241
  vk_pipeline pipeline_norm_f32;
240
242
  vk_pipeline pipeline_group_norm_f32;
241
243
  vk_pipeline pipeline_rms_norm_f32;
244
+ vk_pipeline pipeline_rms_norm_back_f32;
242
245
  vk_pipeline pipeline_gelu_f32;
243
246
  vk_pipeline pipeline_gelu_quick_f32;
244
247
  vk_pipeline pipeline_silu_f32;
248
+ vk_pipeline pipeline_silu_back_f32;
245
249
  vk_pipeline pipeline_relu_f32;
246
250
  vk_pipeline pipeline_leaky_relu_f32;
247
251
  vk_pipeline pipeline_tanh_f32;
252
+ vk_pipeline pipeline_sigmoid_f32;
248
253
  vk_pipeline pipeline_diag_mask_inf_f32;
249
254
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
250
255
  vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
256
+ vk_pipeline pipeline_soft_max_back_f32;
251
257
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
252
258
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
259
+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
260
+ vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
253
261
  vk_pipeline pipeline_argsort_f32;
254
262
  vk_pipeline pipeline_sum_rows_f32;
263
+ vk_pipeline pipeline_argmax_f32;
264
+ vk_pipeline pipeline_count_equal_i32;
255
265
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
256
266
  vk_pipeline pipeline_timestep_embedding_f32;
257
267
  vk_pipeline pipeline_pool2d_f32;
258
268
  vk_pipeline pipeline_rwkv_wkv6_f32;
269
+ vk_pipeline pipeline_opt_step_adamw_f32;
259
270
 
260
271
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
261
272
  vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -493,6 +504,11 @@ struct vk_op_rope_push_constants {
493
504
  float corr_dims[2];
494
505
  float theta_scale;
495
506
  uint32_t has_ff;
507
+ uint32_t ne02;
508
+ uint32_t s1;
509
+ uint32_t s2;
510
+ int32_t sections[4];
511
+ uint32_t is_back;
496
512
  };
497
513
 
498
514
  struct vk_op_soft_max_push_constants {
@@ -1294,7 +1310,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
1294
1310
  static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
1295
1311
  vk_buffer buf;
1296
1312
  try {
1297
- if (device->uma) {
1313
+ if (device->prefer_host_memory) {
1314
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
1315
+ } else if (device->uma) {
1298
1316
  // Fall back to host memory type
1299
1317
  buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
1300
1318
  } else {
@@ -1378,7 +1396,37 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
1378
1396
  return {64, 64};
1379
1397
  };
1380
1398
 
1381
- static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
1399
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
1400
+
1401
+ uint32_t lut_size = 0;
1402
+ switch (src0_type) {
1403
+ case GGML_TYPE_IQ1_S:
1404
+ case GGML_TYPE_IQ1_M:
1405
+ lut_size = 2*2048;
1406
+ break;
1407
+ case GGML_TYPE_IQ2_XXS:
1408
+ lut_size = 8*256;
1409
+ break;
1410
+ case GGML_TYPE_IQ2_XS:
1411
+ lut_size = 8*512;
1412
+ break;
1413
+ case GGML_TYPE_IQ2_S:
1414
+ lut_size = 8*1024;
1415
+ break;
1416
+ case GGML_TYPE_IQ3_XXS:
1417
+ lut_size = 4*256;
1418
+ break;
1419
+ case GGML_TYPE_IQ3_S:
1420
+ lut_size = 4*512;
1421
+ break;
1422
+ case GGML_TYPE_IQ4_NL:
1423
+ case GGML_TYPE_IQ4_XS:
1424
+ lut_size = 4*16;
1425
+ break;
1426
+ default:
1427
+ break;
1428
+ }
1429
+
1382
1430
  // Needs to be kept up to date on shader changes
1383
1431
  const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
1384
1432
  const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
@@ -1388,13 +1436,20 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1388
1436
  const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
1389
1437
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1390
1438
 
1391
- return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
1439
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
1440
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
1441
+
1442
+ VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
1443
+ "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
1444
+
1445
+ return supported;
1392
1446
  }
1393
1447
 
1394
1448
  static void ggml_vk_load_shaders(vk_device& device) {
1395
1449
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
1396
1450
 
1397
1451
  // some shaders have a minimum subgroup size
1452
+ const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
1398
1453
  const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
1399
1454
  const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
1400
1455
 
@@ -1457,13 +1512,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
1457
1512
  const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
1458
1513
  const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
1459
1514
 
1460
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1461
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1462
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1515
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
1516
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
1517
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
1463
1518
 
1464
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1465
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1466
- s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1519
+ l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
1520
+ m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
1521
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
1467
1522
 
1468
1523
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1469
1524
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
@@ -1472,62 +1527,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
1472
1527
  m_align = 64;
1473
1528
  s_align = 32;
1474
1529
 
1475
- // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1476
- // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1477
- // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1478
- // But the numbers happen to work out for 32KB shared memory size that when using the medium
1479
- // size there's enough room for everything, and we assert for this.
1480
- uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1481
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1482
- l_warptile = m_warptile;
1483
- l_wg_denoms = m_wg_denoms;
1484
- shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1485
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1486
- }
1487
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1488
- // assert mul_mat_mat_id shaders will fit.
1489
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1490
- }
1491
-
1492
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1493
- if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1494
- if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
1495
- l_warptile_mmq = m_warptile_mmq;
1496
- l_mmq_wg_denoms = m_mmq_wg_denoms;
1497
- } else {
1498
- l_warptile_mmq = s_warptile_mmq;
1499
- l_mmq_wg_denoms = s_mmq_wg_denoms;
1530
+ for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
1531
+ ggml_type t = (ggml_type)i;
1532
+ // Disable medium and large matrix multiplication if not enough shared memory is available
1533
+ // Check mmq warptiles as the largest configuration
1534
+ // Throw an error if not enough for any matrix multiplication is available
1535
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
1536
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1537
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1538
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
1539
+ device->mul_mat_m[i] = false;
1540
+ device->mul_mat_l[i] = false;
1541
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
1542
+ device->mul_mat_l[i] = false;
1543
+ }
1544
+
1545
+ // Disable mul_mat_id if not enough shared memory is available
1546
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
1547
+ device->mul_mat_id_s[i] = false;
1548
+ device->mul_mat_id_m[i] = false;
1549
+ device->mul_mat_id_l[i] = false;
1550
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
1551
+ device->mul_mat_id_m[i] = false;
1552
+ device->mul_mat_id_l[i] = false;
1553
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
1554
+ device->mul_mat_id_l[i] = false;
1500
1555
  }
1501
- shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1502
- GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1503
- }
1504
- if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1505
- // assert mul_mat_mat_id shaders will fit.
1506
- GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1507
- }
1508
- // Disable medium and large matrix multiplication if not enough shared memory is available
1509
- // Check mmq warptiles as the largest configuration
1510
- // Throw an error if not enough for any matrix multiplication is available
1511
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
1512
- std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1513
- throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1514
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
1515
- device->mul_mat_m = false;
1516
- device->mul_mat_l = false;
1517
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
1518
- device->mul_mat_l = false;
1519
- }
1520
-
1521
- // Disable mul_mat_id if not enough shared memory is available
1522
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
1523
- device->mul_mat_id_s = false;
1524
- device->mul_mat_id_m = false;
1525
- device->mul_mat_id_l = false;
1526
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
1527
- device->mul_mat_id_m = false;
1528
- device->mul_mat_id_l = false;
1529
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
1530
- device->mul_mat_id_l = false;
1531
1556
  }
1532
1557
  }
1533
1558
 
@@ -1617,6 +1642,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1617
1642
  //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1618
1643
  //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1619
1644
  //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1645
+ //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
1646
+ //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
1620
1647
  //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
1621
1648
  //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
1622
1649
  //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
@@ -1651,6 +1678,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1651
1678
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1652
1679
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1653
1680
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1681
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1682
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1654
1683
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1655
1684
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1656
1685
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1670,6 +1699,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
1670
1699
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1671
1700
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1672
1701
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1702
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1703
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1673
1704
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1674
1705
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1675
1706
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1684,119 +1715,124 @@ static void ggml_vk_load_shaders(vk_device& device) {
1684
1715
  #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1685
1716
  if (device->coopmat_support) {
1686
1717
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1687
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1688
- if (device->mul_mat ## ID ## _l) \
1718
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1719
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1689
1720
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
1690
- if (device->mul_mat ## ID ## _m) \
1721
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1691
1722
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
1692
- if (device->mul_mat ## ID ## _s) \
1723
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1693
1724
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
1694
- if (device->mul_mat ## ID ## _l) \
1725
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1695
1726
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
1696
- if (device->mul_mat ## ID ## _m) \
1727
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1697
1728
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
1698
- if (device->mul_mat ## ID ## _s) \
1729
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1699
1730
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
1700
1731
 
1701
1732
  // Create 2 variants, {f16,f32} accumulator
1702
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1733
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1703
1734
  if (device->coopmat_acc_f16_support) { \
1704
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1735
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1705
1736
  } \
1706
1737
  if (device->coopmat_acc_f32_support) { \
1707
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1738
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1708
1739
  } \
1709
1740
 
1710
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1711
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1712
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1713
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1741
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1742
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1743
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1744
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1714
1745
 
1715
1746
  if (device->coopmat_acc_f16_support) {
1716
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1717
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1718
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1719
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1720
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1721
-
1722
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1723
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1724
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1725
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1726
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1727
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1728
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1729
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1730
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1731
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1732
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1733
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1747
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1748
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1749
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1750
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1751
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1752
+
1753
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1754
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1755
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1756
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1757
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1758
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1759
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1760
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1761
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1762
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1763
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1764
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1765
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1766
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1734
1767
  } else {
1735
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1736
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1737
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1738
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1739
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1740
-
1741
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1742
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1743
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1744
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1745
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1746
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1747
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1748
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1749
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1750
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1751
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1752
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1753
- }
1754
-
1755
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1756
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1757
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1758
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1759
- CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1760
-
1761
- if (device->coopmat_acc_f16_support) {
1762
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1763
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1764
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1765
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1766
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1767
-
1768
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1769
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1770
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1771
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1772
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1773
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1774
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1775
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1776
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1777
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1778
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1779
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1780
- } else {
1781
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1782
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1783
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1784
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1785
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1786
-
1787
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1788
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1789
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1790
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1791
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1792
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1793
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1794
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1795
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1796
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1798
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1799
- }
1768
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1769
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1770
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1771
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1772
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1773
+
1774
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1775
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1776
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1777
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1778
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1779
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1780
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1781
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1782
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1783
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1784
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1785
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1786
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1787
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1788
+ }
1789
+
1790
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1791
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1792
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1793
+
1794
+ if (device->coopmat_acc_f16_support) {
1795
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1796
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1798
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1799
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1800
+
1801
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1802
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1803
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1804
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1805
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1806
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1807
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1808
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1809
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1810
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1811
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1812
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1813
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1814
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1815
+ } else {
1816
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1817
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1818
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1819
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1820
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1821
+
1822
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1823
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1824
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1825
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1826
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1827
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1828
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1829
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1830
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1831
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1832
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1833
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1834
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1835
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1800
1836
  }
1801
1837
  #undef CREATE_MM2
1802
1838
  #undef CREATE_MM
@@ -1804,141 +1840,143 @@ static void ggml_vk_load_shaders(vk_device& device) {
1804
1840
  #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1805
1841
  if (device->fp16) {
1806
1842
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1807
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1808
- if (device->mul_mat ## ID ## _l) \
1843
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1844
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1809
1845
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1810
- if (device->mul_mat ## ID ## _m) \
1846
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1811
1847
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1812
- if (device->mul_mat ## ID ## _s) \
1848
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1813
1849
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1814
- if (device->mul_mat ## ID ## _l) \
1850
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1815
1851
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1816
- if (device->mul_mat ## ID ## _m) \
1852
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1817
1853
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1818
- if (device->mul_mat ## ID ## _s) \
1854
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1819
1855
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1820
1856
 
1821
1857
  // Create 2 variants, {f16,f32} accumulator
1822
- #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1823
- CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1824
- CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1825
-
1826
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1827
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1828
- CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1829
- CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1830
-
1831
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1832
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1833
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1834
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1835
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1836
-
1837
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1838
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1839
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1840
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1841
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1842
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1843
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1844
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1845
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1846
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1847
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1848
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1849
-
1850
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1851
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1852
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1853
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1854
- CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1855
-
1856
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1857
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1858
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1859
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1860
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1861
-
1862
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1863
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1864
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1865
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1866
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1867
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1868
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1869
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1870
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1871
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1872
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1873
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1874
- }
1858
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1859
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1860
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1861
+
1862
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1863
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1864
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1865
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1866
+
1867
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1868
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1869
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1870
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1871
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1872
+
1873
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1874
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1875
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1876
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1877
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1878
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1879
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1880
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1881
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1882
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1883
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1884
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1885
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1886
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1887
+
1888
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1889
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1890
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1891
+
1892
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1893
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1894
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1895
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1896
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1897
+
1898
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1899
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1900
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1901
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1902
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1903
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1904
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1905
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1906
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1907
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1908
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1909
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1910
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1911
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1875
1912
  #undef CREATE_MM2
1876
1913
  #undef CREATE_MM
1877
1914
  } else {
1878
1915
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1879
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1880
- if (device->mul_mat ## ID ## _l) \
1916
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1917
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1881
1918
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1882
- if (device->mul_mat ## ID ## _m) \
1919
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1883
1920
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1884
- if (device->mul_mat ## ID ## _s) \
1921
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1885
1922
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1886
- if (device->mul_mat ## ID ## _l) \
1923
+ if (device->mul_mat ## ID ## _l[TYPE]) \
1887
1924
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1888
- if (device->mul_mat ## ID ## _m) \
1925
+ if (device->mul_mat ## ID ## _m[TYPE]) \
1889
1926
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1890
- if (device->mul_mat ## ID ## _s) \
1927
+ if (device->mul_mat ## ID ## _s[TYPE]) \
1891
1928
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1892
1929
 
1893
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1894
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1895
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1896
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1897
-
1898
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1899
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1900
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1901
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1902
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1903
-
1904
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1905
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1906
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1907
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1908
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1909
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1910
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1911
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1912
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1913
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1914
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1915
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1916
-
1917
- // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1918
- if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1919
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1920
- CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1921
- CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1922
-
1923
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1924
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1925
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1926
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1927
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1928
-
1929
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1930
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1931
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1932
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1933
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1934
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1935
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1936
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1937
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1938
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1939
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1940
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1941
- }
1930
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1931
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1932
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1933
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1934
+
1935
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1936
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1937
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1938
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1939
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1940
+
1941
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1942
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1943
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1944
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1945
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1946
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1947
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1948
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1949
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1950
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1951
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1952
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1953
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1954
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1955
+
1956
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1957
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1958
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1959
+
1960
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1961
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1962
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1963
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1964
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1965
+
1966
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1967
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1968
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1969
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1970
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1971
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1972
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1973
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1974
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1975
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1976
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1977
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1978
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1979
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1942
1980
  #undef CREATE_MM
1943
1981
  }
1944
1982
 
@@ -1954,6 +1992,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1954
1992
  }
1955
1993
  } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
1956
1994
  rm_stdq = 2;
1995
+ uint32_t rm_iq = 2 * rm_kq;
1957
1996
 
1958
1997
  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
1959
1998
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -1968,13 +2007,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
1968
2007
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1969
2008
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1970
2009
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1971
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1972
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1973
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1974
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1975
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1976
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1977
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
2010
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2011
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2012
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2013
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2014
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2015
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2016
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2017
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2018
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
1978
2019
 
1979
2020
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
1980
2021
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -1988,13 +2029,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
1988
2029
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1989
2030
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1990
2031
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1991
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1992
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1993
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1994
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1995
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1996
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1997
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
2032
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2033
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2034
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2035
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2036
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2037
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2038
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2039
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
2040
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
1998
2041
  }
1999
2042
 
2000
2043
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2009,13 +2052,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
2009
2052
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2010
2053
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2011
2054
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2012
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2013
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2014
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2015
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2016
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2017
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
2018
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
2055
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2056
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2057
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2058
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2059
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2060
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2061
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2062
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2063
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
2019
2064
 
2020
2065
  // dequant shaders
2021
2066
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2029,6 +2074,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2029
2074
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2030
2075
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
2031
2076
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
2077
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2078
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2032
2079
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2033
2080
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
2034
2081
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
@@ -2045,6 +2092,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2045
2092
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2046
2093
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2047
2094
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2095
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2096
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2048
2097
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2049
2098
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2050
2099
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2060,6 +2109,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2060
2109
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2061
2110
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2062
2111
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2112
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2113
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2063
2114
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2064
2115
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2065
2116
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2076,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2076
2127
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2077
2128
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2078
2129
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2130
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2079
2131
 
2080
2132
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2081
2133
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2106,6 +2158,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2106
2158
 
2107
2159
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2108
2160
 
2161
+ ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2162
+ ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2109
2163
  ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2110
2164
  ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2111
2165
  ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -2128,13 +2182,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
2128
2182
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2129
2183
 
2130
2184
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2185
+ ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2131
2186
 
2132
2187
  ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2133
2188
  ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2134
2189
  ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2190
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2135
2191
  ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2136
2192
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2137
2193
  ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2194
+ ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2138
2195
 
2139
2196
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
2140
2197
 
@@ -2142,22 +2199,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
2142
2199
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
2143
2200
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2144
2201
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
2202
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2145
2203
 
2146
2204
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2147
2205
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2206
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2207
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2148
2208
 
2149
2209
  if (device->float_controls_rte_fp16) {
2150
2210
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2151
2211
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2212
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2213
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2152
2214
  } else {
2153
2215
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2154
2216
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2217
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2218
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2155
2219
  }
2156
2220
 
2157
2221
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
2158
2222
 
2223
+ ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2224
+
2159
2225
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2160
2226
 
2227
+ ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
2228
+
2161
2229
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2162
2230
  if (device->float_controls_rte_fp16) {
2163
2231
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -2171,6 +2239,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2171
2239
 
2172
2240
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2173
2241
 
2242
+ ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2243
+
2174
2244
  for (auto &c : compiles) {
2175
2245
  c.wait();
2176
2246
  }
@@ -2206,6 +2276,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2206
2276
  device->physical_device = physical_devices[dev_num];
2207
2277
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
2208
2278
 
2279
+ const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
2280
+ device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
2281
+
2209
2282
  bool fp16_storage = false;
2210
2283
  bool fp16_compute = false;
2211
2284
  bool maintenance4_support = false;
@@ -2623,34 +2696,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
2623
2696
 
2624
2697
  // Shaders
2625
2698
  // Disable matmul tile sizes early if performance low or not supported
2626
- switch (device->vendor_id) {
2699
+ for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
2700
+ switch (device->vendor_id) {
2627
2701
  #ifndef GGML_VULKAN_RUN_TESTS
2628
- case VK_VENDOR_ID_AMD:
2629
- case VK_VENDOR_ID_INTEL:
2630
- device->mul_mat_l = false;
2631
- device->mul_mat_m = true;
2632
- device->mul_mat_s = true;
2633
- device->mul_mat_id_l = false;
2634
- device->mul_mat_id_m = true;
2635
- device->mul_mat_id_s = true;
2636
- break;
2637
- case VK_VENDOR_ID_APPLE:
2638
- device->mul_mat_l = false;
2639
- device->mul_mat_m = true;
2640
- device->mul_mat_s = false;
2641
- device->mul_mat_id_l = false;
2642
- device->mul_mat_id_m = true;
2643
- device->mul_mat_id_s = false;
2644
- break;
2702
+ case VK_VENDOR_ID_AMD:
2703
+ case VK_VENDOR_ID_INTEL:
2704
+ device->mul_mat_l[i] = false;
2705
+ device->mul_mat_m[i] = true;
2706
+ device->mul_mat_s[i] = true;
2707
+ device->mul_mat_id_l[i] = false;
2708
+ device->mul_mat_id_m[i] = true;
2709
+ device->mul_mat_id_s[i] = true;
2710
+ break;
2711
+ case VK_VENDOR_ID_APPLE:
2712
+ device->mul_mat_l[i] = false;
2713
+ device->mul_mat_m[i] = true;
2714
+ device->mul_mat_s[i] = false;
2715
+ device->mul_mat_id_l[i] = false;
2716
+ device->mul_mat_id_m[i] = true;
2717
+ device->mul_mat_id_s[i] = false;
2718
+ break;
2645
2719
  #endif
2646
- default:
2647
- device->mul_mat_l = true;
2648
- device->mul_mat_m = true;
2649
- device->mul_mat_s = true;
2650
- device->mul_mat_id_l = true;
2651
- device->mul_mat_id_m = true;
2652
- device->mul_mat_id_s = true;
2653
- break;
2720
+ default:
2721
+ device->mul_mat_l[i] = true;
2722
+ device->mul_mat_m[i] = true;
2723
+ device->mul_mat_s[i] = true;
2724
+ device->mul_mat_id_l[i] = true;
2725
+ device->mul_mat_id_m[i] = true;
2726
+ device->mul_mat_id_s[i] = true;
2727
+ break;
2728
+ }
2654
2729
  }
2655
2730
 
2656
2731
  ggml_vk_load_shaders(device);
@@ -2780,8 +2855,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2780
2855
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2781
2856
 
2782
2857
  std::string device_name = props2.properties.deviceName.data();
2783
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
2784
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
2858
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
2859
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
2860
+ props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
2785
2861
 
2786
2862
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
2787
2863
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2791,14 +2867,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2791
2867
  static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
2792
2868
  static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
2793
2869
 
2794
- void ggml_vk_instance_init() {
2870
+ static void ggml_vk_instance_init() {
2795
2871
  if (vk_instance_initialized) {
2796
2872
  return;
2797
2873
  }
2798
2874
  VK_LOG_DEBUG("ggml_vk_instance_init()");
2799
2875
 
2800
- vk_instance_initialized = true;
2801
-
2802
2876
  uint32_t api_version = vk::enumerateInstanceVersion();
2803
2877
 
2804
2878
  if (api_version < VK_API_VERSION_1_2) {
@@ -2849,6 +2923,7 @@ void ggml_vk_instance_init() {
2849
2923
  GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
2850
2924
  }
2851
2925
  vk_instance.instance = vk::createInstance(instance_create_info);
2926
+ vk_instance_initialized = true;
2852
2927
 
2853
2928
  size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
2854
2929
 
@@ -2873,7 +2948,7 @@ void ggml_vk_instance_init() {
2873
2948
  // Make sure at least one device exists
2874
2949
  if (devices.empty()) {
2875
2950
  std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
2876
- GGML_ABORT("fatal error");
2951
+ return;
2877
2952
  }
2878
2953
 
2879
2954
  // Default to using all dedicated GPUs
@@ -3007,6 +3082,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
3007
3082
  case GGML_TYPE_Q4_K:
3008
3083
  case GGML_TYPE_Q5_K:
3009
3084
  case GGML_TYPE_Q6_K:
3085
+ case GGML_TYPE_IQ1_S:
3086
+ case GGML_TYPE_IQ1_M:
3010
3087
  case GGML_TYPE_IQ2_XXS:
3011
3088
  case GGML_TYPE_IQ2_XS:
3012
3089
  case GGML_TYPE_IQ2_S:
@@ -3061,6 +3138,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
3061
3138
  case GGML_TYPE_Q4_K:
3062
3139
  case GGML_TYPE_Q5_K:
3063
3140
  case GGML_TYPE_Q6_K:
3141
+ case GGML_TYPE_IQ1_S:
3142
+ case GGML_TYPE_IQ1_M:
3064
3143
  case GGML_TYPE_IQ2_XXS:
3065
3144
  case GGML_TYPE_IQ2_XS:
3066
3145
  case GGML_TYPE_IQ2_S:
@@ -3098,6 +3177,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
3098
3177
  case GGML_TYPE_Q4_K:
3099
3178
  case GGML_TYPE_Q5_K:
3100
3179
  case GGML_TYPE_Q6_K:
3180
+ case GGML_TYPE_IQ1_S:
3181
+ case GGML_TYPE_IQ1_M:
3101
3182
  case GGML_TYPE_IQ2_XXS:
3102
3183
  case GGML_TYPE_IQ2_XS:
3103
3184
  case GGML_TYPE_IQ2_S:
@@ -3147,6 +3228,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
3147
3228
  case GGML_TYPE_Q4_K:
3148
3229
  case GGML_TYPE_Q5_K:
3149
3230
  case GGML_TYPE_Q6_K:
3231
+ case GGML_TYPE_IQ1_S:
3232
+ case GGML_TYPE_IQ1_M:
3150
3233
  case GGML_TYPE_IQ2_XXS:
3151
3234
  case GGML_TYPE_IQ2_XS:
3152
3235
  case GGML_TYPE_IQ2_S:
@@ -3179,6 +3262,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
3179
3262
  case GGML_TYPE_Q4_K:
3180
3263
  case GGML_TYPE_Q5_K:
3181
3264
  case GGML_TYPE_Q6_K:
3265
+ case GGML_TYPE_IQ1_S:
3266
+ case GGML_TYPE_IQ1_M:
3182
3267
  case GGML_TYPE_IQ2_XXS:
3183
3268
  case GGML_TYPE_IQ2_XS:
3184
3269
  case GGML_TYPE_IQ2_S:
@@ -3721,6 +3806,12 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
3721
3806
  }
3722
3807
  }
3723
3808
 
3809
+ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3810
+ VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
3811
+
3812
+ ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
3813
+ }
3814
+
3724
3815
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
3725
3816
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
3726
3817
 
@@ -3755,31 +3846,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
3755
3846
  return split_k;
3756
3847
  }
3757
3848
 
3758
- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3759
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3849
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3850
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3760
3851
 
3761
3852
  if (ctx->device->coopmat2) {
3762
- if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3853
+ if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
3763
3854
  return aligned ? mmp->a_l : mmp->l;
3764
3855
  }
3765
- if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3856
+ if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
3766
3857
  return aligned ? mmp->a_m : mmp->m;
3767
3858
  }
3768
3859
  return aligned ? mmp->a_s : mmp->s;
3769
3860
  }
3770
3861
 
3771
- if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
3862
+ if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
3772
3863
  return aligned ? mmp->a_s : mmp->s;
3773
3864
  }
3774
- if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
3865
+ if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
3775
3866
  return aligned ? mmp->a_m : mmp->m;
3776
3867
  }
3777
3868
  return aligned ? mmp->a_l : mmp->l;
3778
3869
  }
3779
3870
 
3780
- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3781
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3782
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
3871
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
3872
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
3873
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
3783
3874
  }
3784
3875
 
3785
3876
  static void ggml_vk_matmul(
@@ -3806,31 +3897,31 @@ static void ggml_vk_matmul(
3806
3897
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
3807
3898
  }
3808
3899
 
3809
- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3810
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3900
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3901
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3811
3902
 
3812
3903
  if (ctx->device->coopmat2) {
3813
- if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
3904
+ if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
3814
3905
  return aligned ? mmp->a_l : mmp->l;
3815
3906
  }
3816
- if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
3907
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
3817
3908
  return aligned ? mmp->a_m : mmp->m;
3818
3909
  }
3819
3910
  return aligned ? mmp->a_s : mmp->s;
3820
3911
  }
3821
3912
 
3822
- if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
3913
+ if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
3823
3914
  return aligned ? mmp->a_s : mmp->s;
3824
3915
  }
3825
- if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
3916
+ if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
3826
3917
  return aligned ? mmp->a_m : mmp->m;
3827
3918
  }
3828
3919
  return aligned ? mmp->a_l : mmp->l;
3829
3920
  }
3830
3921
 
3831
- static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3832
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3833
- return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
3922
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
3923
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
3924
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
3834
3925
  }
3835
3926
 
3836
3927
  static void ggml_vk_matmul_id(
@@ -4011,10 +4102,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4011
4102
  const int y_ne = ne11 * ne10;
4012
4103
  const int d_ne = ne11 * ne01;
4013
4104
 
4014
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
4105
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4015
4106
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4016
4107
 
4017
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
4108
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4018
4109
 
4019
4110
  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
4020
4111
 
@@ -4102,7 +4193,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4102
4193
  }
4103
4194
  if (qy_needs_dequant) {
4104
4195
  d_Y = ctx->prealloc_y;
4105
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
4196
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
4106
4197
  } else {
4107
4198
  d_Y = d_Qy;
4108
4199
  y_buf_offset = qy_buf_offset;
@@ -4593,10 +4684,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4593
4684
  const uint64_t y_ne = ne11 * ne10;
4594
4685
  const uint64_t d_ne = ne21 * ne20;
4595
4686
 
4596
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
4687
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4597
4688
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
4598
4689
 
4599
- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
4690
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4600
4691
 
4601
4692
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4602
4693
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -4679,7 +4770,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4679
4770
  }
4680
4771
  if (qy_needs_dequant) {
4681
4772
  d_Y = ctx->prealloc_y;
4682
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
4773
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
4683
4774
  } else {
4684
4775
  d_Y = d_Qy;
4685
4776
  y_buf_offset = qy_buf_offset;
@@ -5127,6 +5218,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5127
5218
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
5128
5219
  }
5129
5220
  return nullptr;
5221
+ case GGML_OP_SUB:
5222
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5223
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
5224
+ }
5225
+ return nullptr;
5130
5226
  case GGML_OP_MUL:
5131
5227
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5132
5228
  return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5188,10 +5284,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5188
5284
  return ctx->device->pipeline_repeat_f32;
5189
5285
  }
5190
5286
  return nullptr;
5287
+ case GGML_OP_REPEAT_BACK:
5288
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5289
+ return ctx->device->pipeline_repeat_back_f32;
5290
+ }
5291
+ return nullptr;
5191
5292
  case GGML_OP_CPY:
5192
5293
  case GGML_OP_CONT:
5193
5294
  case GGML_OP_DUP:
5194
5295
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
5296
+ case GGML_OP_SILU_BACK:
5297
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5298
+ return ctx->device->pipeline_silu_back_f32;
5299
+ }
5300
+ return nullptr;
5195
5301
  case GGML_OP_NORM:
5196
5302
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5197
5303
  return ctx->device->pipeline_norm_f32;
@@ -5207,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5207
5313
  return ctx->device->pipeline_rms_norm_f32;
5208
5314
  }
5209
5315
  return nullptr;
5316
+ case GGML_OP_RMS_NORM_BACK:
5317
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5318
+ return ctx->device->pipeline_rms_norm_back_f32;
5319
+ }
5320
+ return nullptr;
5210
5321
  case GGML_OP_UNARY:
5211
5322
  switch (ggml_get_unary_op(dst)) {
5212
5323
  case GGML_UNARY_OP_SILU:
@@ -5234,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5234
5345
  return ctx->device->pipeline_tanh_f32;
5235
5346
  }
5236
5347
  break;
5348
+ case GGML_UNARY_OP_SIGMOID:
5349
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5350
+ return ctx->device->pipeline_sigmoid_f32;
5351
+ }
5352
+ break;
5237
5353
  default:
5238
5354
  break;
5239
5355
  }
@@ -5253,10 +5369,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5253
5369
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
5254
5370
  }
5255
5371
  return nullptr;
5372
+ case GGML_OP_SOFT_MAX_BACK:
5373
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5374
+ return ctx->device->pipeline_soft_max_back_f32;
5375
+ }
5376
+ return nullptr;
5256
5377
  case GGML_OP_ROPE:
5378
+ case GGML_OP_ROPE_BACK:
5257
5379
  {
5258
5380
  const int mode = ((const int32_t *) dst->op_params)[2];
5259
5381
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5382
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5383
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5260
5384
 
5261
5385
  if (is_neox) {
5262
5386
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -5265,6 +5389,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5265
5389
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5266
5390
  return ctx->device->pipeline_rope_neox_f16;
5267
5391
  }
5392
+ } else if (is_mrope && !is_vision) {
5393
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5394
+ return ctx->device->pipeline_rope_multi_f32;
5395
+ }
5396
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5397
+ return ctx->device->pipeline_rope_multi_f16;
5398
+ }
5399
+ } else if (is_vision) {
5400
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5401
+ return ctx->device->pipeline_rope_vision_f32;
5402
+ }
5403
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
5404
+ return ctx->device->pipeline_rope_vision_f16;
5405
+ }
5268
5406
  } else {
5269
5407
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5270
5408
  return ctx->device->pipeline_rope_norm_f32;
@@ -5280,11 +5418,22 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5280
5418
  return ctx->device->pipeline_argsort_f32;
5281
5419
  }
5282
5420
  return nullptr;
5421
+ case GGML_OP_SUM:
5283
5422
  case GGML_OP_SUM_ROWS:
5284
5423
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5285
5424
  return ctx->device->pipeline_sum_rows_f32;
5286
5425
  }
5287
5426
  return nullptr;
5427
+ case GGML_OP_ARGMAX:
5428
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
5429
+ return ctx->device->pipeline_argmax_f32;
5430
+ }
5431
+ return nullptr;
5432
+ case GGML_OP_COUNT_EQUAL:
5433
+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
5434
+ return ctx->device->pipeline_count_equal_i32;
5435
+ }
5436
+ return nullptr;
5288
5437
  case GGML_OP_IM2COL:
5289
5438
  if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5290
5439
  return ctx->device->pipeline_im2col_f32;
@@ -5308,6 +5457,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5308
5457
  return ctx->device->pipeline_rwkv_wkv6_f32;
5309
5458
  }
5310
5459
  return nullptr;
5460
+ case GGML_OP_OPT_STEP_ADAMW:
5461
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5462
+ return ctx->device->pipeline_opt_step_adamw_f32;
5463
+ }
5464
+ return nullptr;
5311
5465
  case GGML_OP_LEAKY_RELU:
5312
5466
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5313
5467
  return ctx->device->pipeline_leaky_relu_f32;
@@ -5325,6 +5479,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5325
5479
  case GGML_OP_CPY:
5326
5480
  case GGML_OP_GET_ROWS:
5327
5481
  case GGML_OP_ADD:
5482
+ case GGML_OP_SUB:
5328
5483
  case GGML_OP_MUL:
5329
5484
  case GGML_OP_DIV:
5330
5485
  case GGML_OP_CONCAT:
@@ -5335,6 +5490,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5335
5490
  case GGML_OP_CLAMP:
5336
5491
  case GGML_OP_PAD:
5337
5492
  case GGML_OP_REPEAT:
5493
+ case GGML_OP_REPEAT_BACK:
5494
+ case GGML_OP_ROPE:
5338
5495
  return true;
5339
5496
  default:
5340
5497
  return false;
@@ -5546,8 +5703,11 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5546
5703
  switch (op) {
5547
5704
  case GGML_OP_NORM:
5548
5705
  case GGML_OP_RMS_NORM:
5706
+ case GGML_OP_RMS_NORM_BACK:
5549
5707
  case GGML_OP_SOFT_MAX:
5708
+ case GGML_OP_SOFT_MAX_BACK:
5550
5709
  case GGML_OP_SUM_ROWS:
5710
+ case GGML_OP_ARGMAX:
5551
5711
  {
5552
5712
  const uint32_t nr = ggml_nrows(src0);
5553
5713
  if (nr > 262144) {
@@ -5558,6 +5718,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5558
5718
  elements = { nr, 1, 1 };
5559
5719
  }
5560
5720
  } break;
5721
+ case GGML_OP_SUM:
5722
+ // We use GGML_OP_SUM_ROWS with 1 row.
5723
+ elements = { 1, 1, 1 };
5724
+ break;
5561
5725
  case GGML_OP_GROUP_NORM:
5562
5726
  {
5563
5727
  const uint32_t num_groups = dst->op_params[0];
@@ -5565,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5565
5729
  } break;
5566
5730
  case GGML_OP_DIAG_MASK_INF:
5567
5731
  case GGML_OP_ROPE:
5732
+ case GGML_OP_ROPE_BACK:
5568
5733
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
5569
5734
  break;
5570
5735
  case GGML_OP_GET_ROWS:
@@ -5604,6 +5769,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5604
5769
  elements = { N * OC * OH * OW, 1, 1};
5605
5770
  } break;
5606
5771
  case GGML_OP_ADD:
5772
+ case GGML_OP_SUB:
5607
5773
  case GGML_OP_DIV:
5608
5774
  case GGML_OP_MUL:
5609
5775
  case GGML_OP_SCALE:
@@ -5613,6 +5779,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5613
5779
  case GGML_OP_CLAMP:
5614
5780
  case GGML_OP_PAD:
5615
5781
  case GGML_OP_REPEAT:
5782
+ case GGML_OP_REPEAT_BACK:
5616
5783
  case GGML_OP_CPY:
5617
5784
  case GGML_OP_CONCAT:
5618
5785
  case GGML_OP_UPSCALE:
@@ -5658,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5658
5825
 
5659
5826
  ggml_vk_sync_buffers(subctx);
5660
5827
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5661
- } else if (op == GGML_OP_ROPE) {
5828
+ } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
5662
5829
  // Empty src2 is possible in rope, but the shader needs a buffer
5663
5830
  vk_subbuffer subbuf_z;
5664
5831
  if (use_src2) {
@@ -5673,6 +5840,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5673
5840
  // im2col uses only src1 and dst buffers
5674
5841
  ggml_vk_sync_buffers(subctx);
5675
5842
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5843
+ } else if (op == GGML_OP_COUNT_EQUAL) {
5844
+ ggml_vk_sync_buffers(subctx);
5845
+ // count_equal assumes that destination buffer is initialized with zeroes
5846
+ ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
5847
+ ggml_vk_sync_buffers(subctx);
5848
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
5676
5849
  } else if (use_src2) {
5677
5850
  ggml_vk_sync_buffers(subctx);
5678
5851
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -5735,6 +5908,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
5735
5908
  }, dryrun);
5736
5909
  }
5737
5910
 
5911
+ static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5912
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
5913
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
5914
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
5915
+
5916
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
5917
+ (uint32_t)ggml_nelements(src0),
5918
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
5919
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
5920
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
5921
+ 0,
5922
+ 0.0f, 0.0f, 0,
5923
+ }, dryrun);
5924
+ }
5925
+
5738
5926
  static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5739
5927
  const uint32_t src0_type_size = ggml_type_size(src0->type);
5740
5928
  const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -5893,6 +6081,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
5893
6081
  );
5894
6082
  }
5895
6083
 
6084
+ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
6085
+ const ggml_tensor * x = dst->src[0];
6086
+ const ggml_tensor * g = dst->src[1];
6087
+ const ggml_tensor * gm = dst->src[2];
6088
+ const ggml_tensor * gv = dst->src[3];
6089
+ const ggml_tensor * p = dst->src[4];
6090
+
6091
+ GGML_ASSERT(x->type == GGML_TYPE_F32);
6092
+ GGML_ASSERT(g->type == GGML_TYPE_F32);
6093
+ GGML_ASSERT(gm->type == GGML_TYPE_F32);
6094
+ GGML_ASSERT(gv->type == GGML_TYPE_F32);
6095
+ GGML_ASSERT(p->type == GGML_TYPE_F32);
6096
+ GGML_ASSERT(dst->buffer != nullptr);
6097
+ GGML_ASSERT(ggml_is_contiguous(x));
6098
+ GGML_ASSERT(ggml_is_contiguous(g));
6099
+ GGML_ASSERT(ggml_is_contiguous(gm));
6100
+ GGML_ASSERT(ggml_is_contiguous(gv));
6101
+ GGML_ASSERT(ggml_is_contiguous(p));
6102
+ GGML_ASSERT(ggml_are_same_shape(x, g));
6103
+ GGML_ASSERT(ggml_are_same_shape(x, gm));
6104
+ GGML_ASSERT(ggml_are_same_shape(x, gv));
6105
+ GGML_ASSERT(ggml_nelements(p) == 7);
6106
+
6107
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
6108
+ GGML_ASSERT(pipeline != nullptr);
6109
+
6110
+ if (dryrun) {
6111
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6112
+ return;
6113
+ }
6114
+
6115
+ ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
6116
+ ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
6117
+ ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
6118
+ ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
6119
+ ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
6120
+
6121
+ ggml_vk_sync_buffers(subctx);
6122
+
6123
+ vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
6124
+ size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
6125
+ bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
6126
+
6127
+ if (ctx->device->uma) {
6128
+ ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
6129
+ ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
6130
+ ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
6131
+ ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
6132
+ ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
6133
+
6134
+ X_uma = d_X != nullptr;
6135
+ G_uma = d_G != nullptr;
6136
+ GM_uma = d_GM != nullptr;
6137
+ GV_uma = d_GV != nullptr;
6138
+ P_uma = d_P != nullptr;
6139
+ }
6140
+
6141
+ if (!X_uma) {
6142
+ d_X = x_buf_ctx->dev_buffer;
6143
+ x_offset = vk_tensor_offset(x) + x->view_offs;
6144
+ }
6145
+ if (!G_uma) {
6146
+ d_G = g_buf_ctx->dev_buffer;
6147
+ g_offset = vk_tensor_offset(g) + g->view_offs;
6148
+ }
6149
+ if (!GM_uma) {
6150
+ d_GM = gm_buf_ctx->dev_buffer;
6151
+ gm_offset = vk_tensor_offset(gm) + gm->view_offs;
6152
+ }
6153
+ if (!GV_uma) {
6154
+ d_GV = gv_buf_ctx->dev_buffer;
6155
+ gv_offset = vk_tensor_offset(gv) + gv->view_offs;
6156
+ }
6157
+ if (!P_uma) {
6158
+ d_P = p_buf_ctx->dev_buffer;
6159
+ p_offset = vk_tensor_offset(p) + p->view_offs;
6160
+ }
6161
+
6162
+ const uint64_t x_size = ggml_nbytes(x);
6163
+ const uint64_t g_size = ggml_nbytes(g);
6164
+ const uint64_t gm_size = ggml_nbytes(gm);
6165
+ const uint64_t gv_size = ggml_nbytes(gv);
6166
+ const uint64_t p_size = ggml_nbytes(p);
6167
+
6168
+ std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
6169
+
6170
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
6171
+ vk_subbuffer{ d_X, x_offset, x_size },
6172
+ vk_subbuffer{ d_G, g_offset, g_size },
6173
+ vk_subbuffer{ d_GM, gm_offset, gm_size },
6174
+ vk_subbuffer{ d_GV, gv_offset, gv_size },
6175
+ vk_subbuffer{ d_P, p_offset, p_size },
6176
+ }, sizeof(vk_op_push_constants), &pc, elements);
6177
+ }
6178
+
6179
+ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
6180
+ const size_t n = ggml_nelements(dst->src[0]);
6181
+
6182
+ ggml_vk_op_f32_opt_step_adamw(
6183
+ ctx, subctx, dst,
6184
+ { (uint32_t)n, 0, 0.0f, 0.0f },
6185
+ dryrun
6186
+ );
6187
+ }
6188
+
5896
6189
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
5897
6190
  int * op_params = (int *)dst->op_params;
5898
6191
 
@@ -6026,6 +6319,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
6026
6319
  }, dryrun);
6027
6320
  }
6028
6321
 
6322
+ static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6323
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
6324
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
6325
+
6326
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
6327
+ (uint32_t)ggml_nelements(dst),
6328
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6329
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6330
+ 0,
6331
+ 0.0f, 0.0f,
6332
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6333
+ }, dryrun);
6334
+ }
6335
+
6029
6336
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6030
6337
  const uint32_t src0_type_size = ggml_type_size(src0->type);
6031
6338
  const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -6040,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
6040
6347
  }, dryrun);
6041
6348
  }
6042
6349
 
6350
+ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6351
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6352
+ }
6353
+
6043
6354
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6044
6355
  float * op_params = (float *)dst->op_params;
6045
6356
 
@@ -6062,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
6062
6373
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6063
6374
  }
6064
6375
 
6376
+ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6377
+ float * op_params = (float *)dst->op_params;
6378
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6379
+ }
6380
+
6065
6381
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6066
6382
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6067
6383
  }
@@ -6097,9 +6413,14 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
6097
6413
  }, dryrun);
6098
6414
  }
6099
6415
 
6100
- static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
6416
+ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6417
+ float * op_params = (float *)dst->op_params;
6418
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
6419
+ }
6420
+
6421
+ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
6101
6422
  const int n_dims = ((int32_t *) dst->op_params)[1];
6102
- // const int mode = ((int32_t *) dst->op_params)[2];
6423
+ const int mode = ((int32_t *) dst->op_params)[2];
6103
6424
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
6104
6425
  const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
6105
6426
  const float freq_base = ((float *) dst->op_params)[5];
@@ -6108,16 +6429,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
6108
6429
  const float attn_factor = ((float *) dst->op_params)[8];
6109
6430
  const float beta_fast = ((float *) dst->op_params)[9];
6110
6431
  const float beta_slow = ((float *) dst->op_params)[10];
6432
+ int sections[4] {};
6433
+ if (mode & GGML_ROPE_TYPE_MROPE) {
6434
+ memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
6435
+ }
6111
6436
 
6112
6437
  float corr_dims[2];
6113
6438
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
6114
6439
 
6115
6440
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
6116
6441
 
6442
+ uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
6443
+ uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
6444
+
6117
6445
  ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
6118
6446
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
6119
6447
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
6120
- src2 != nullptr,
6448
+ src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
6449
+ sections[0], sections[1], sections[2], sections[3], backprop
6121
6450
  }, dryrun);
6122
6451
  }
6123
6452
 
@@ -6140,10 +6469,22 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
6140
6469
  }, dryrun);
6141
6470
  }
6142
6471
 
6472
+ static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6473
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6474
+ }
6475
+
6143
6476
  static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6144
6477
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
6145
6478
  }
6146
6479
 
6480
+ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6481
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
6482
+ }
6483
+
6484
+ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6485
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6486
+ }
6487
+
6147
6488
  static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
6148
6489
  const int32_t s0 = dst->op_params[0];
6149
6490
  const int32_t s1 = dst->op_params[1];
@@ -7002,15 +7343,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7002
7343
  case GGML_UNARY_OP_GELU_QUICK:
7003
7344
  case GGML_UNARY_OP_RELU:
7004
7345
  case GGML_UNARY_OP_TANH:
7346
+ case GGML_UNARY_OP_SIGMOID:
7005
7347
  break;
7006
7348
  default:
7007
7349
  return false;
7008
7350
  }
7009
7351
  break;
7010
7352
  case GGML_OP_REPEAT:
7353
+ case GGML_OP_REPEAT_BACK:
7011
7354
  case GGML_OP_GET_ROWS:
7012
7355
  case GGML_OP_ADD:
7013
7356
  case GGML_OP_ACC:
7357
+ case GGML_OP_SUB:
7014
7358
  case GGML_OP_MUL:
7015
7359
  case GGML_OP_DIV:
7016
7360
  case GGML_OP_CONCAT:
@@ -7024,22 +7368,30 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7024
7368
  case GGML_OP_CPY:
7025
7369
  case GGML_OP_CONT:
7026
7370
  case GGML_OP_DUP:
7371
+ case GGML_OP_SILU_BACK:
7027
7372
  case GGML_OP_NORM:
7028
7373
  case GGML_OP_GROUP_NORM:
7029
7374
  case GGML_OP_RMS_NORM:
7375
+ case GGML_OP_RMS_NORM_BACK:
7030
7376
  case GGML_OP_DIAG_MASK_INF:
7031
7377
  case GGML_OP_SOFT_MAX:
7378
+ case GGML_OP_SOFT_MAX_BACK:
7032
7379
  case GGML_OP_ROPE:
7380
+ case GGML_OP_ROPE_BACK:
7033
7381
  case GGML_OP_MUL_MAT:
7034
7382
  case GGML_OP_MUL_MAT_ID:
7035
7383
  case GGML_OP_ARGSORT:
7384
+ case GGML_OP_SUM:
7036
7385
  case GGML_OP_SUM_ROWS:
7386
+ case GGML_OP_ARGMAX:
7387
+ case GGML_OP_COUNT_EQUAL:
7037
7388
  case GGML_OP_IM2COL:
7038
7389
  case GGML_OP_TIMESTEP_EMBEDDING:
7039
7390
  case GGML_OP_POOL_2D:
7040
7391
  case GGML_OP_RWKV_WKV6:
7041
7392
  case GGML_OP_LEAKY_RELU:
7042
7393
  case GGML_OP_FLASH_ATTN_EXT:
7394
+ case GGML_OP_OPT_STEP_ADAMW:
7043
7395
  break;
7044
7396
  default:
7045
7397
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7060,9 +7412,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7060
7412
  } else {
7061
7413
  switch (node->op) {
7062
7414
  case GGML_OP_REPEAT:
7415
+ case GGML_OP_REPEAT_BACK:
7063
7416
  case GGML_OP_ACC:
7064
7417
  case GGML_OP_GET_ROWS:
7065
7418
  case GGML_OP_ADD:
7419
+ case GGML_OP_SUB:
7066
7420
  case GGML_OP_MUL:
7067
7421
  case GGML_OP_DIV:
7068
7422
  case GGML_OP_CONCAT:
@@ -7076,15 +7430,22 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7076
7430
  case GGML_OP_CPY:
7077
7431
  case GGML_OP_CONT:
7078
7432
  case GGML_OP_DUP:
7433
+ case GGML_OP_SILU_BACK:
7079
7434
  case GGML_OP_NORM:
7080
7435
  case GGML_OP_GROUP_NORM:
7081
7436
  case GGML_OP_RMS_NORM:
7437
+ case GGML_OP_RMS_NORM_BACK:
7082
7438
  case GGML_OP_UNARY:
7083
7439
  case GGML_OP_DIAG_MASK_INF:
7084
7440
  case GGML_OP_SOFT_MAX:
7441
+ case GGML_OP_SOFT_MAX_BACK:
7085
7442
  case GGML_OP_ROPE:
7443
+ case GGML_OP_ROPE_BACK:
7086
7444
  case GGML_OP_ARGSORT:
7445
+ case GGML_OP_SUM:
7087
7446
  case GGML_OP_SUM_ROWS:
7447
+ case GGML_OP_ARGMAX:
7448
+ case GGML_OP_COUNT_EQUAL:
7088
7449
  case GGML_OP_IM2COL:
7089
7450
  case GGML_OP_TIMESTEP_EMBEDDING:
7090
7451
  case GGML_OP_POOL_2D:
@@ -7105,6 +7466,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7105
7466
  case GGML_OP_REPEAT:
7106
7467
  ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
7107
7468
 
7469
+ break;
7470
+ case GGML_OP_REPEAT_BACK:
7471
+ ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
7472
+
7108
7473
  break;
7109
7474
  case GGML_OP_ACC:
7110
7475
  ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7117,6 +7482,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7117
7482
  case GGML_OP_ADD:
7118
7483
  ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
7119
7484
 
7485
+ break;
7486
+ case GGML_OP_SUB:
7487
+ ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
7488
+
7120
7489
  break;
7121
7490
  case GGML_OP_MUL:
7122
7491
  ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7163,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7163
7532
  case GGML_OP_DUP:
7164
7533
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
7165
7534
 
7535
+ break;
7536
+ case GGML_OP_SILU_BACK:
7537
+ ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
7538
+
7166
7539
  break;
7167
7540
  case GGML_OP_NORM:
7168
7541
  ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7175,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7175
7548
  case GGML_OP_RMS_NORM:
7176
7549
  ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
7177
7550
 
7551
+ break;
7552
+ case GGML_OP_RMS_NORM_BACK:
7553
+ ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
7554
+
7178
7555
  break;
7179
7556
  case GGML_OP_UNARY:
7180
7557
  switch (ggml_get_unary_op(node)) {
@@ -7183,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7183
7560
  case GGML_UNARY_OP_GELU_QUICK:
7184
7561
  case GGML_UNARY_OP_RELU:
7185
7562
  case GGML_UNARY_OP_TANH:
7563
+ case GGML_UNARY_OP_SIGMOID:
7186
7564
  ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
7187
7565
  break;
7188
7566
  default:
@@ -7196,18 +7574,38 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7196
7574
  case GGML_OP_SOFT_MAX:
7197
7575
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
7198
7576
 
7577
+ break;
7578
+ case GGML_OP_SOFT_MAX_BACK:
7579
+ ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
7580
+
7199
7581
  break;
7200
7582
  case GGML_OP_ROPE:
7201
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
7583
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
7584
+
7585
+ break;
7586
+ case GGML_OP_ROPE_BACK:
7587
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);
7202
7588
 
7203
7589
  break;
7204
7590
  case GGML_OP_ARGSORT:
7205
7591
  ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
7206
7592
 
7593
+ break;
7594
+ case GGML_OP_SUM:
7595
+ ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
7596
+
7207
7597
  break;
7208
7598
  case GGML_OP_SUM_ROWS:
7209
7599
  ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
7210
7600
 
7601
+ break;
7602
+ case GGML_OP_ARGMAX:
7603
+ ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
7604
+
7605
+ break;
7606
+ case GGML_OP_COUNT_EQUAL:
7607
+ ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
7608
+
7211
7609
  break;
7212
7610
  case GGML_OP_IM2COL:
7213
7611
  ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7242,6 +7640,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7242
7640
  case GGML_OP_RWKV_WKV6:
7243
7641
  ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
7244
7642
 
7643
+ break;
7644
+
7645
+ case GGML_OP_OPT_STEP_ADAMW:
7646
+ ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
7647
+
7245
7648
  break;
7246
7649
  default:
7247
7650
  return false;
@@ -7293,6 +7696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7293
7696
  case GGML_OP_ADD:
7294
7697
  case GGML_OP_ACC:
7295
7698
  case GGML_OP_GET_ROWS:
7699
+ case GGML_OP_SUB:
7296
7700
  case GGML_OP_MUL:
7297
7701
  case GGML_OP_DIV:
7298
7702
  case GGML_OP_CONCAT:
@@ -7306,25 +7710,34 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7306
7710
  case GGML_OP_CPY:
7307
7711
  case GGML_OP_CONT:
7308
7712
  case GGML_OP_DUP:
7713
+ case GGML_OP_SILU_BACK:
7309
7714
  case GGML_OP_NORM:
7310
7715
  case GGML_OP_GROUP_NORM:
7311
7716
  case GGML_OP_RMS_NORM:
7717
+ case GGML_OP_RMS_NORM_BACK:
7312
7718
  case GGML_OP_DIAG_MASK_INF:
7313
7719
  case GGML_OP_SOFT_MAX:
7720
+ case GGML_OP_SOFT_MAX_BACK:
7314
7721
  case GGML_OP_ROPE:
7722
+ case GGML_OP_ROPE_BACK:
7315
7723
  case GGML_OP_RESHAPE:
7316
7724
  case GGML_OP_VIEW:
7317
7725
  case GGML_OP_PERMUTE:
7318
7726
  case GGML_OP_TRANSPOSE:
7319
7727
  case GGML_OP_NONE:
7320
7728
  case GGML_OP_ARGSORT:
7729
+ case GGML_OP_SUM:
7321
7730
  case GGML_OP_SUM_ROWS:
7731
+ case GGML_OP_ARGMAX:
7732
+ case GGML_OP_COUNT_EQUAL:
7322
7733
  case GGML_OP_IM2COL:
7323
7734
  case GGML_OP_TIMESTEP_EMBEDDING:
7324
7735
  case GGML_OP_POOL_2D:
7325
7736
  case GGML_OP_RWKV_WKV6:
7326
7737
  case GGML_OP_LEAKY_RELU:
7327
7738
  case GGML_OP_REPEAT:
7739
+ case GGML_OP_REPEAT_BACK:
7740
+ case GGML_OP_OPT_STEP_ADAMW:
7328
7741
  buf = tensor->buffer;
7329
7742
 
7330
7743
  break;
@@ -7335,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7335
7748
  case GGML_UNARY_OP_GELU_QUICK:
7336
7749
  case GGML_UNARY_OP_RELU:
7337
7750
  case GGML_UNARY_OP_TANH:
7751
+ case GGML_UNARY_OP_SIGMOID:
7338
7752
  buf = tensor->buffer;
7339
7753
  break;
7340
7754
  default:
@@ -7509,11 +7923,21 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
7509
7923
  UNUSED(buffer);
7510
7924
  }
7511
7925
 
7512
- static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
7926
+ static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
7513
7927
  VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
7514
7928
  if (tensor->view_src != nullptr) {
7515
7929
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
7516
7930
  }
7931
+ return GGML_STATUS_SUCCESS;
7932
+ }
7933
+
7934
+ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
7935
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
7936
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
7937
+ vk_buffer buf = buf_ctx->dev_buffer;
7938
+
7939
+ uint32_t val32 = (uint32_t)value * 0x01010101;
7940
+ ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
7517
7941
  }
7518
7942
 
7519
7943
  static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@@ -7560,7 +7984,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
7560
7984
  /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
7561
7985
  /* .get_base = */ ggml_backend_vk_buffer_get_base,
7562
7986
  /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
7563
- /* .memset_tensor = */ NULL,
7987
+ /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
7564
7988
  /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
7565
7989
  /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
7566
7990
  /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
@@ -8027,7 +8451,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8027
8451
  case GGML_UNARY_OP_SILU:
8028
8452
  case GGML_UNARY_OP_RELU:
8029
8453
  case GGML_UNARY_OP_TANH:
8030
- return ggml_is_contiguous(op->src[0]);
8454
+ case GGML_UNARY_OP_SIGMOID:
8455
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
8031
8456
  default:
8032
8457
  return false;
8033
8458
  }
@@ -8035,13 +8460,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8035
8460
  case GGML_OP_MUL_MAT:
8036
8461
  case GGML_OP_MUL_MAT_ID:
8037
8462
  {
8463
+ ggml_type src0_type = op->src[0]->type;
8038
8464
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
8039
8465
  const vk_device& device = ggml_vk_get_device(ctx->device);
8040
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
8466
+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
8041
8467
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
8042
8468
  return false;
8043
8469
  }
8044
- switch (op->src[0]->type) {
8470
+ switch (src0_type) {
8045
8471
  case GGML_TYPE_F32:
8046
8472
  case GGML_TYPE_F16:
8047
8473
  case GGML_TYPE_Q4_0:
@@ -8054,6 +8480,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8054
8480
  case GGML_TYPE_Q4_K:
8055
8481
  case GGML_TYPE_Q5_K:
8056
8482
  case GGML_TYPE_Q6_K:
8483
+ case GGML_TYPE_IQ1_S:
8484
+ case GGML_TYPE_IQ1_M:
8057
8485
  case GGML_TYPE_IQ2_XXS:
8058
8486
  case GGML_TYPE_IQ2_XS:
8059
8487
  case GGML_TYPE_IQ2_S:
@@ -8128,6 +8556,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8128
8556
  //case GGML_TYPE_Q4_K:
8129
8557
  //case GGML_TYPE_Q5_K:
8130
8558
  //case GGML_TYPE_Q6_K:
8559
+ //case GGML_TYPE_IQ1_S:
8560
+ //case GGML_TYPE_IQ1_M:
8131
8561
  //case GGML_TYPE_IQ2_XXS:
8132
8562
  //case GGML_TYPE_IQ2_XS:
8133
8563
  //case GGML_TYPE_IQ2_S:
@@ -8151,6 +8581,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8151
8581
  case GGML_TYPE_Q5_0:
8152
8582
  case GGML_TYPE_Q5_1:
8153
8583
  case GGML_TYPE_Q8_0:
8584
+ case GGML_TYPE_IQ1_S:
8585
+ case GGML_TYPE_IQ1_M:
8154
8586
  case GGML_TYPE_IQ2_XXS:
8155
8587
  case GGML_TYPE_IQ2_XS:
8156
8588
  case GGML_TYPE_IQ2_S:
@@ -8206,17 +8638,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8206
8638
  } break;
8207
8639
  case GGML_OP_REPEAT:
8208
8640
  return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
8641
+ case GGML_OP_REPEAT_BACK:
8642
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
8209
8643
  case GGML_OP_ROPE:
8210
- {
8211
- const int mode = ((const int32_t *) op->op_params)[2];
8212
- if (mode & GGML_ROPE_TYPE_MROPE) {
8213
- return false;
8214
- }
8215
- if (mode & GGML_ROPE_TYPE_VISION) {
8216
- return false;
8217
- }
8218
- return ggml_is_contiguous(op->src[0]);
8219
- }
8644
+ case GGML_OP_ROPE_BACK:
8220
8645
  case GGML_OP_NONE:
8221
8646
  case GGML_OP_RESHAPE:
8222
8647
  case GGML_OP_VIEW:
@@ -8228,26 +8653,35 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8228
8653
  case GGML_OP_RMS_NORM:
8229
8654
  return ggml_is_contiguous(op->src[0]);
8230
8655
  case GGML_OP_ADD:
8231
- case GGML_OP_ACC:
8656
+ case GGML_OP_SUB:
8232
8657
  case GGML_OP_MUL:
8233
8658
  case GGML_OP_DIV:
8234
- case GGML_OP_CONCAT:
8235
- case GGML_OP_UPSCALE:
8236
- case GGML_OP_SCALE:
8659
+ case GGML_OP_SILU_BACK:
8660
+ case GGML_OP_RMS_NORM_BACK:
8237
8661
  case GGML_OP_SQR:
8238
8662
  case GGML_OP_SIN:
8239
8663
  case GGML_OP_COS:
8240
8664
  case GGML_OP_CLAMP:
8665
+ return op->src[0]->type == GGML_TYPE_F32;
8666
+ case GGML_OP_ACC:
8667
+ case GGML_OP_CONCAT:
8668
+ case GGML_OP_UPSCALE:
8669
+ case GGML_OP_SCALE:
8241
8670
  case GGML_OP_PAD:
8242
8671
  case GGML_OP_DIAG_MASK_INF:
8243
8672
  case GGML_OP_SOFT_MAX:
8673
+ case GGML_OP_SOFT_MAX_BACK:
8244
8674
  case GGML_OP_ARGSORT:
8675
+ case GGML_OP_SUM:
8245
8676
  case GGML_OP_SUM_ROWS:
8677
+ case GGML_OP_ARGMAX:
8678
+ case GGML_OP_COUNT_EQUAL:
8246
8679
  case GGML_OP_IM2COL:
8247
8680
  case GGML_OP_TIMESTEP_EMBEDDING:
8248
8681
  case GGML_OP_POOL_2D:
8249
8682
  case GGML_OP_RWKV_WKV6:
8250
8683
  case GGML_OP_LEAKY_RELU:
8684
+ case GGML_OP_OPT_STEP_ADAMW:
8251
8685
  return true;
8252
8686
  default:
8253
8687
  return false;
@@ -8347,8 +8781,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
8347
8781
  /* .iface = */ ggml_backend_vk_reg_i,
8348
8782
  /* .context = */ nullptr,
8349
8783
  };
8350
-
8351
- return &reg;
8784
+ try {
8785
+ ggml_vk_instance_init();
8786
+ return &reg;
8787
+ } catch (const vk::SystemError& e) {
8788
+ VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
8789
+ return nullptr;
8790
+ }
8352
8791
  }
8353
8792
 
8354
8793
  // Extension availability
@@ -8515,8 +8954,6 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8515
8954
 
8516
8955
  ggml_tensor * src0 = tensor->src[0];
8517
8956
  ggml_tensor * src1 = tensor->src[1];
8518
- ggml_tensor * src2 = tensor->src[2];
8519
- ggml_tensor * src3 = tensor->src[3];
8520
8957
 
8521
8958
  struct ggml_init_params iparams = {
8522
8959
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -8526,239 +8963,121 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8526
8963
 
8527
8964
  struct ggml_context * ggml_ctx = ggml_init(iparams);
8528
8965
 
8529
- struct ggml_tensor * src0_clone = nullptr;
8530
- struct ggml_tensor * src1_clone = nullptr;
8531
- struct ggml_tensor * src2_clone = nullptr;
8532
- struct ggml_tensor * src3_clone = nullptr;
8533
- struct ggml_tensor * tensor_clone = nullptr;
8534
-
8535
- size_t src0_size;
8536
- size_t src1_size;
8537
- size_t src2_size;
8538
- size_t src3_size;
8539
-
8540
- void * src0_buffer = nullptr;
8541
- void * src1_buffer = nullptr;
8542
- void * src2_buffer = nullptr;
8543
- void * src3_buffer = nullptr;
8544
-
8545
- if (src0 != nullptr) {
8546
- src0_clone = ggml_dup_tensor(ggml_ctx, src0);
8547
-
8548
- src0_size = ggml_nbytes(src0);
8549
-
8550
- src0_buffer = malloc(src0_size);
8551
- src0_clone->data = src0_buffer;
8552
- if (ggml_backend_buffer_is_host(src0->buffer)) {
8553
- memcpy(src0_clone->data, src0->data, src0_size);
8554
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
8555
- } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
8556
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
8557
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8558
- uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
8559
- if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
8560
- for (int i3 = 0; i3 < src0->ne[3]; i3++) {
8561
- for (int i2 = 0; i2 < src0->ne[2]; i2++) {
8562
- const int idx = i3*src0->ne[2] + i2;
8563
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
8564
- }
8565
- }
8566
-
8567
- src0_clone->nb[0] = src0->nb[0];
8568
- src0_clone->nb[1] = src0->nb[1];
8569
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8570
- src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
8571
- }
8572
- } else {
8573
- if (offset + src0_size >= buffer_gpu->size) {
8574
- src0_size = buffer_gpu->size - offset;
8575
- }
8576
- ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
8577
- memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
8578
- }
8579
- } else {
8580
- GGML_ABORT("fatal error");
8581
- }
8966
+ std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
8967
+ std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
8968
+ std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
8969
+ const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
8582
8970
 
8583
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8584
- ggml_vk_print_tensor(src0, "src0");
8585
- }
8586
- }
8587
- if (src1 != nullptr) {
8588
- src1_clone = ggml_dup_tensor(ggml_ctx, src1);
8589
-
8590
- src1_size = ggml_nbytes(src1);
8591
-
8592
- src1_buffer = malloc(src1_size);
8593
- src1_clone->data = src1_buffer;
8594
- if (ggml_backend_buffer_is_host(src1->buffer)) {
8595
- memcpy(src1_clone->data, src1->data, src1_size);
8596
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
8597
- } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
8598
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
8599
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8600
- uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
8601
- if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
8602
- for (int i3 = 0; i3 < src1->ne[3]; i3++) {
8603
- for (int i2 = 0; i2 < src1->ne[2]; i2++) {
8604
- const int idx = i3*src1->ne[2] + i2;
8605
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
8606
- }
8607
- }
8608
-
8609
- src1_clone->nb[0] = src1->nb[0];
8610
- src1_clone->nb[1] = src1->nb[1];
8611
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8612
- src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
8613
- }
8614
- } else {
8615
- if (offset + src1_size >= buffer_gpu->size) {
8616
- src1_size = buffer_gpu->size - offset;
8617
- }
8618
- ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
8619
- memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
8620
- }
8621
- } else {
8622
- GGML_ABORT("fatal error");
8623
- }
8624
-
8625
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8626
- ggml_vk_print_tensor(src1, "src1");
8627
- }
8628
- }
8629
- if (src2 != nullptr) {
8630
- src2_clone = ggml_dup_tensor(ggml_ctx, src2);
8631
-
8632
- src2_size = ggml_nbytes(src2);
8633
-
8634
- src2_buffer = malloc(src2_size);
8635
- src2_clone->data = src2_buffer;
8636
- if (ggml_backend_buffer_is_host(src2->buffer)) {
8637
- memcpy(src2_clone->data, src2->data, src2_size);
8638
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
8639
- } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
8640
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
8641
- vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8642
- uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
8643
- if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
8644
- for (int i3 = 0; i3 < src2->ne[3]; i3++) {
8645
- for (int i2 = 0; i2 < src2->ne[2]; i2++) {
8646
- const int idx = i3*src2->ne[2] + i2;
8647
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
8648
- }
8649
- }
8650
-
8651
- src2_clone->nb[0] = src2->nb[0];
8652
- src2_clone->nb[1] = src2->nb[1];
8653
- for (int i = 2; i < GGML_MAX_DIMS; i++) {
8654
- src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
8655
- }
8656
- } else {
8657
- if (offset + src2_size >= buffer_gpu->size) {
8658
- src2_size = buffer_gpu->size - offset;
8659
- }
8660
- ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
8661
- memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
8662
- }
8663
- } else {
8664
- GGML_ABORT("fatal error");
8665
- }
8971
+ struct ggml_tensor * tensor_clone = nullptr;
8666
8972
 
8667
- if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8668
- ggml_vk_print_tensor(src2, "src2");
8973
+ for (int i = 0; i < 6; i++) {
8974
+ ggml_tensor * srci = tensor->src[i];
8975
+ if (srci == nullptr) {
8976
+ continue;
8669
8977
  }
8670
- }
8671
- if (src3 != nullptr) {
8672
- src3_clone = ggml_dup_tensor(ggml_ctx, src3);
8978
+ ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
8979
+ size_t srci_size = ggml_nbytes(srci);
8673
8980
 
8674
- src3_size = ggml_nbytes(src3);
8981
+ src_clone[i] = srci_clone;
8982
+ src_size[i] = ggml_nbytes(srci);
8983
+ src_buffer[i] = malloc(srci_size);
8675
8984
 
8676
- src3_buffer = malloc(src3_size);
8677
- src3_clone->data = src3_buffer;
8678
- if (ggml_backend_buffer_is_host(src3->buffer)) {
8679
- memcpy(src3_clone->data, src3->data, src3_size);
8680
- memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8681
- } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
8682
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
8985
+ srci_clone->data = src_buffer[i];
8986
+ if (ggml_backend_buffer_is_host(srci->buffer)) {
8987
+ memcpy(srci_clone->data, srci->data, srci_size);
8988
+ memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
8989
+ } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
8990
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
8683
8991
  vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8684
- uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
8685
- if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
8686
- for (int i3 = 0; i3 < src3->ne[3]; i3++) {
8687
- for (int i2 = 0; i2 < src3->ne[2]; i2++) {
8688
- const int idx = i3*src3->ne[2] + i2;
8689
- ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
8992
+ uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
8993
+ if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
8994
+ for (int i3 = 0; i3 < srci->ne[3]; i3++) {
8995
+ for (int i2 = 0; i2 < srci->ne[2]; i2++) {
8996
+ const int idx = i3*srci->ne[2] + i2;
8997
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
8690
8998
  }
8691
8999
  }
8692
9000
 
8693
- src3_clone->nb[0] = src3->nb[0];
8694
- src3_clone->nb[1] = src3->nb[1];
9001
+ srci_clone->nb[0] = srci->nb[0];
9002
+ srci_clone->nb[1] = srci->nb[1];
8695
9003
  for (int i = 2; i < GGML_MAX_DIMS; i++) {
8696
- src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
9004
+ srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
8697
9005
  }
8698
9006
  } else {
8699
- if (offset + src3_size >= buffer_gpu->size) {
8700
- src3_size = buffer_gpu->size - offset;
9007
+ if (offset + srci_size >= buffer_gpu->size) {
9008
+ srci_size = buffer_gpu->size - offset;
8701
9009
  }
8702
- ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
8703
- memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
9010
+ ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
9011
+ memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
8704
9012
  }
8705
9013
  } else {
8706
9014
  GGML_ABORT("fatal error");
8707
9015
  }
8708
9016
 
8709
9017
  if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8710
- ggml_vk_print_tensor(src3, "src3");
9018
+ ggml_vk_print_tensor(srci, srci_name[i]);
8711
9019
  }
8712
9020
  }
8713
9021
 
8714
9022
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
8715
9023
  const float *params = (const float *)tensor->op_params;
8716
- tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
9024
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
8717
9025
  } else if (tensor->op == GGML_OP_MUL_MAT) {
8718
- tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
9026
+ tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
8719
9027
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
8720
- tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
9028
+ tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
9029
+ } else if (tensor->op == GGML_OP_SUB) {
9030
+ tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
8721
9031
  } else if (tensor->op == GGML_OP_MUL) {
8722
- tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
9032
+ tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
8723
9033
  } else if (tensor->op == GGML_OP_DIV) {
8724
- tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
9034
+ tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
8725
9035
  } else if (tensor->op == GGML_OP_CONCAT) {
8726
- tensor_clone = ggml_concat(ggml_ctx, src0_clone, src1_clone, *(int *)tensor->op_params);
9036
+ tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
8727
9037
  } else if (tensor->op == GGML_OP_UPSCALE) {
8728
- tensor_clone = ggml_upscale_ext(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9038
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8729
9039
  } else if (tensor->op == GGML_OP_SCALE) {
8730
- tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
9040
+ tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
8731
9041
  } else if (tensor->op == GGML_OP_SQR) {
8732
- tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
9042
+ tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
8733
9043
  } else if (tensor->op == GGML_OP_SIN) {
8734
- tensor_clone = ggml_sin(ggml_ctx, src0_clone);
9044
+ tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
8735
9045
  } else if (tensor->op == GGML_OP_COS) {
8736
- tensor_clone = ggml_cos(ggml_ctx, src0_clone);
9046
+ tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
8737
9047
  } else if (tensor->op == GGML_OP_CLAMP) {
8738
- tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9048
+ tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8739
9049
  } else if (tensor->op == GGML_OP_PAD) {
8740
- tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
9050
+ tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
8741
9051
  } else if (tensor->op == GGML_OP_REPEAT) {
8742
- tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
9052
+ tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
9053
+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
9054
+ tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
8743
9055
  } else if (tensor->op == GGML_OP_ADD) {
8744
- tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
9056
+ tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
8745
9057
  } else if (tensor->op == GGML_OP_ACC) {
8746
- tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
9058
+ tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
8747
9059
  } else if (tensor->op == GGML_OP_NORM) {
8748
- tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
9060
+ tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
8749
9061
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
8750
- tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
9062
+ tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
8751
9063
  } else if (tensor->op == GGML_OP_RMS_NORM) {
8752
- tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
9064
+ tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9065
+ } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
9066
+ const float eps = ((float *) tensor->op_params)[0];
9067
+ tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
9068
+ } else if (tensor->op == GGML_OP_SILU_BACK) {
9069
+ tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
8753
9070
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
8754
9071
  if (src1 != nullptr) {
8755
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9072
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8756
9073
  } else {
8757
- tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
9074
+ tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
8758
9075
  }
9076
+ } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9077
+ tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
8759
9078
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
8760
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
8761
- } else if (tensor->op == GGML_OP_ROPE) {
9079
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
9080
+ } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
8762
9081
  const int n_dims = ((int32_t *) tensor->op_params)[1];
8763
9082
  const int mode = ((int32_t *) tensor->op_params)[2];
8764
9083
  //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8769,23 +9088,39 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8769
9088
  const float attn_factor = ((float *) tensor->op_params)[8];
8770
9089
  const float beta_fast = ((float *) tensor->op_params)[9];
8771
9090
  const float beta_slow = ((float *) tensor->op_params)[10];
8772
- tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9091
+ if (mode & GGML_ROPE_TYPE_MROPE) {
9092
+ int32_t *sections = ((int32_t *) tensor->op_params) + 11;
9093
+ if (tensor->op == GGML_OP_ROPE) {
9094
+ tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9095
+ } else {
9096
+ tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9097
+ }
9098
+ } else {
9099
+ if (tensor->op == GGML_OP_ROPE) {
9100
+ tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9101
+ } else {
9102
+ tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
9103
+ }
9104
+ }
8773
9105
  } else if (tensor->op == GGML_OP_UNARY) {
8774
9106
  switch (ggml_get_unary_op(tensor)) {
8775
9107
  case GGML_UNARY_OP_SILU:
8776
- tensor_clone = ggml_silu(ggml_ctx, src0_clone);
9108
+ tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
8777
9109
  break;
8778
9110
  case GGML_UNARY_OP_GELU:
8779
- tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
9111
+ tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
8780
9112
  break;
8781
9113
  case GGML_UNARY_OP_GELU_QUICK:
8782
- tensor_clone = ggml_gelu_quick(ggml_ctx, src0_clone);
9114
+ tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
8783
9115
  break;
8784
9116
  case GGML_UNARY_OP_RELU:
8785
- tensor_clone = ggml_relu(ggml_ctx, src0_clone);
9117
+ tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
8786
9118
  break;
8787
9119
  case GGML_UNARY_OP_TANH:
8788
- tensor_clone = ggml_tanh(ggml_ctx, src0_clone);
9120
+ tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
9121
+ break;
9122
+ case GGML_UNARY_OP_SIGMOID:
9123
+ tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
8789
9124
  break;
8790
9125
  default:
8791
9126
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8793,28 +9128,34 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8793
9128
  }
8794
9129
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
8795
9130
  if (src1 == nullptr) {
8796
- tensor_clone = ggml_dup(ggml_ctx, src0_clone);
9131
+ tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
8797
9132
  tensor_clone->type = tensor->type;
8798
9133
  } else {
8799
- tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
9134
+ tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
8800
9135
  }
8801
9136
  } else if (tensor->op == GGML_OP_CONT) {
8802
- tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9137
+ tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8803
9138
  } else if (tensor->op == GGML_OP_RESHAPE) {
8804
- tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9139
+ tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
8805
9140
  } else if (tensor->op == GGML_OP_VIEW) {
8806
- tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
9141
+ tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
8807
9142
  } else if (tensor->op == GGML_OP_PERMUTE) {
8808
9143
  int32_t * params = (int32_t *)tensor->op_params;
8809
- tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
9144
+ tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
8810
9145
  } else if (tensor->op == GGML_OP_TRANSPOSE) {
8811
- tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
9146
+ tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
8812
9147
  } else if (tensor->op == GGML_OP_GET_ROWS) {
8813
- tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
9148
+ tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
8814
9149
  } else if (tensor->op == GGML_OP_ARGSORT) {
8815
- tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
9150
+ tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
9151
+ } else if (tensor->op == GGML_OP_SUM) {
9152
+ tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
8816
9153
  } else if (tensor->op == GGML_OP_SUM_ROWS) {
8817
- tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
9154
+ tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
9155
+ } else if (tensor->op == GGML_OP_ARGMAX) {
9156
+ tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
9157
+ } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
9158
+ tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
8818
9159
  } else if (tensor->op == GGML_OP_IM2COL) {
8819
9160
  const int32_t s0 = tensor->op_params[0];
8820
9161
  const int32_t s1 = tensor->op_params[1];
@@ -8824,11 +9165,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8824
9165
  const int32_t d1 = tensor->op_params[5];
8825
9166
 
8826
9167
  const bool is_2D = tensor->op_params[6] == 1;
8827
- tensor_clone = ggml_im2col(ggml_ctx, src0_clone, src1_clone, s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
9168
+ tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
8828
9169
  } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
8829
9170
  const int32_t dim = tensor->op_params[0];
8830
9171
  const int32_t max_period = tensor->op_params[1];
8831
- tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
9172
+ tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
8832
9173
  } else if (tensor->op == GGML_OP_POOL_2D) {
8833
9174
  enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
8834
9175
  const int32_t k0 = tensor->op_params[1];
@@ -8838,13 +9179,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8838
9179
  const int32_t p0 = tensor->op_params[5];
8839
9180
  const int32_t p1 = tensor->op_params[6];
8840
9181
 
8841
- tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
9182
+ tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
8842
9183
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
8843
9184
  const float * op_params = (const float *)tensor->op_params;
8844
- tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
9185
+ tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
8845
9186
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
8846
- tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8847
- tensor->src[4], tensor->src[5]);
9187
+ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
9188
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
9189
+ } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9190
+ src_clone[0]->flags = src0->flags;
9191
+ tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
9192
+ src_clone[2], src_clone[3], src_clone[4]);
8848
9193
  }
8849
9194
  else {
8850
9195
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -8866,11 +9211,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8866
9211
  memcpy(comp_result, tensor_clone->data, comp_size);
8867
9212
  memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
8868
9213
 
8869
- if (src0 != nullptr) {
8870
- free(src0_buffer);
8871
- }
8872
- if (src1 != nullptr) {
8873
- free(src1_buffer);
9214
+ for (int i = 0; i < 6; i++) {
9215
+ if (src_buffer[i] != nullptr) {
9216
+ free(src_buffer[i]);
9217
+ }
8874
9218
  }
8875
9219
 
8876
9220
  ggml_free(ggml_ctx);
@@ -8934,6 +9278,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8934
9278
  } else if (tensor->type == GGML_TYPE_I32) {
8935
9279
  correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
8936
9280
  result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
9281
+ } else if (tensor->type == GGML_TYPE_I64) {
9282
+ correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
9283
+ result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
8937
9284
  } else {
8938
9285
  std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
8939
9286
  }