@fugood/llama.node 0.3.13 → 0.3.14

Files changed (139)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
@@ -241,15 +241,19 @@ struct vk_device_struct {
  vk_pipeline pipeline_norm_f32;
  vk_pipeline pipeline_group_norm_f32;
  vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_rms_norm_back_f32;
  vk_pipeline pipeline_gelu_f32;
  vk_pipeline pipeline_gelu_quick_f32;
  vk_pipeline pipeline_silu_f32;
+ vk_pipeline pipeline_silu_back_f32;
  vk_pipeline pipeline_relu_f32;
  vk_pipeline pipeline_leaky_relu_f32;
  vk_pipeline pipeline_tanh_f32;
+ vk_pipeline pipeline_sigmoid_f32;
  vk_pipeline pipeline_diag_mask_inf_f32;
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
  vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+ vk_pipeline pipeline_soft_max_back_f32;
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
  vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
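Note: the three backward (`*_back`) handles plus the sigmoid handle added above are wired up later in this file. A hedged summary of the op-to-pipeline mapping (the exact selection logic appears in the ggml_vk_op_get_pipeline hunks below):

    // GGML_OP_SILU_BACK       -> pipeline_silu_back_f32
    // GGML_OP_RMS_NORM_BACK   -> pipeline_rms_norm_back_f32
    // GGML_OP_SOFT_MAX_BACK   -> pipeline_soft_max_back_f32
    // GGML_UNARY_OP_SIGMOID   -> pipeline_sigmoid_f32 (new forward unary op)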
@@ -504,6 +508,7 @@ struct vk_op_rope_push_constants {
  uint32_t s1;
  uint32_t s2;
  int32_t sections[4];
+ uint32_t is_back;
  };

  struct vk_op_soft_max_push_constants {
@@ -1987,6 +1992,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  }
  } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
  rm_stdq = 2;
+ uint32_t rm_iq = 2 * rm_kq;

  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
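Note: `rm_iq` is the rows-per-workgroup factor that the IQ-family pipelines in the following hunks switch to, replacing the K-quant factor `rm_kq`. A minimal sketch of what that workgroup denominator controls, simplified from the dispatch math and not the backend's literal code:

    // With wg_denoms = {rm_iq, 1, 1}, each workgroup covers rm_iq matrix rows,
    // so rm_iq = 2 * rm_kq halves the number of dispatched workgroups per column.
    uint32_t groups_x = (num_rows + rm_iq - 1) / rm_iq; // ceil-divide rows by rows per WG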
@@ -2001,15 +2007,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);

  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2023,15 +2029,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
  }

  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2046,15 +2052,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);

  // dequant shaders
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2121,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2180,9 +2187,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

@@ -2190,6 +2199,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4183,7 +4193,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
@@ -4760,7 +4770,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
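Note: both asserts fix the same sizing bug. `prealloc_y` holds the dequantized copy of src1 (y), and since ggml broadcasts src0 across src1's batch dimensions, ne12/ne13 can exceed ne02/ne03; the old check could pass while the buffer was too small. A hedged restatement of the corrected invariant, using the names from the surrounding code:

    // y_sz is the byte size of one batch of Y; src1 supplies ne12 * ne13 batches,
    // so the preallocated buffer must cover src1's batch count, not src0's.
    GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);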
@@ -5283,6 +5293,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  case GGML_OP_CONT:
  case GGML_OP_DUP:
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+ case GGML_OP_SILU_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_silu_back_f32;
+ }
+ return nullptr;
  case GGML_OP_NORM:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_norm_f32;
@@ -5298,6 +5313,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_rms_norm_f32;
  }
  return nullptr;
+ case GGML_OP_RMS_NORM_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rms_norm_back_f32;
+ }
+ return nullptr;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(dst)) {
  case GGML_UNARY_OP_SILU:
@@ -5325,6 +5345,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_tanh_f32;
  }
  break;
+ case GGML_UNARY_OP_SIGMOID:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sigmoid_f32;
+ }
+ break;
  default:
  break;
  }
@@ -5344,7 +5369,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
  }
  return nullptr;
+ case GGML_OP_SOFT_MAX_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_soft_max_back_f32;
+ }
+ return nullptr;
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  {
  const int mode = ((const int32_t *) dst->op_params)[2];
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5672,7 +5703,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  switch (op) {
  case GGML_OP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_SUM_ROWS:
  case GGML_OP_ARGMAX:
  {
@@ -5696,6 +5729,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  } break;
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
  break;
  case GGML_OP_GET_ROWS:
@@ -5791,7 +5825,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

  ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (op == GGML_OP_ROPE) {
+ } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
  // Empty src2 is possible in rope, but the shader needs a buffer
  vk_subbuffer subbuf_z;
  if (use_src2) {
@@ -6313,6 +6347,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  float * op_params = (float *)dst->op_params;

@@ -6335,6 +6373,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
  }

+ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ }
+
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
  }
@@ -6370,7 +6413,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
  }, dryrun);
  }

- static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+ }
+
+ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6398,7 +6446,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
  src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
- sections[0], sections[1], sections[2], sections[3],
+ sections[0], sections[1], sections[2], sections[3], backprop
  }, dryrun);
  }

@@ -7295,6 +7343,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  break;
  default:
  return false;
@@ -7319,12 +7368,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  case GGML_OP_ARGSORT:
@@ -7377,13 +7430,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_UNARY:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_ARGSORT:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
@@ -7475,6 +7532,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_DUP:
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_SILU_BACK:
+ ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_NORM:
  ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7548,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RMS_NORM:
  ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_RMS_NORM_BACK:
+ ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(node)) {
@@ -7495,6 +7560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
  break;
  default:
@@ -7508,9 +7574,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_SOFT_MAX:
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+ break;
+ case GGML_OP_SOFT_MAX_BACK:
+ ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+ break;
+ case GGML_OP_ROPE_BACK:
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);

  break;
  case GGML_OP_ARGSORT:
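Note on the hunk above: GGML_OP_ROPE and GGML_OP_ROPE_BACK now share one host function, with the new boolean `backprop` argument feeding the `is_back` push constant added to vk_op_rope_push_constants earlier, so the backward op reuses the forward RoPE shaders with the rotation inverted. A sketch of the two call sites, as wired in this hunk:

    ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); // GGML_OP_ROPE
    ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true,  dryrun); // GGML_OP_ROPE_BACK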
@@ -7636,12 +7710,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
@@ -7670,6 +7748,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  buf = tensor->buffer;
  break;
  default:
@@ -7844,11 +7923,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
  UNUSED(buffer);
  }

- static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
  VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
  if (tensor->view_src != nullptr) {
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
  }
+ return GGML_STATUS_SUCCESS;
  }

  static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -8371,7 +8451,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_UNARY_OP_SILU:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
- return ggml_is_contiguous(op->src[0]);
+ case GGML_UNARY_OP_SIGMOID:
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
  default:
  return false;
  }
@@ -8560,6 +8641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_REPEAT_BACK:
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_NONE:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
@@ -8571,20 +8653,24 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_RMS_NORM:
  return ggml_is_contiguous(op->src[0]);
  case GGML_OP_ADD:
- case GGML_OP_ACC:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
+ case GGML_OP_SILU_BACK:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_SQR:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
+ return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ACC:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_SCALE:
  case GGML_OP_PAD:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ARGSORT:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
@@ -8976,15 +9062,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
  } else if (tensor->op == GGML_OP_RMS_NORM) {
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+ const float eps = ((float *) tensor->op_params)[0];
+ tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+ } else if (tensor->op == GGML_OP_SILU_BACK) {
+ tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
  if (src1 != nullptr) {
  tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
  } else {
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
  }
+ } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+ tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
  tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
- } else if (tensor->op == GGML_OP_ROPE) {
+ } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
  const int n_dims = ((int32_t *) tensor->op_params)[1];
  const int mode = ((int32_t *) tensor->op_params)[2];
  //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9090,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  const float beta_slow = ((float *) tensor->op_params)[10];
  if (mode & GGML_ROPE_TYPE_MROPE) {
  int32_t *sections = ((int32_t *) tensor->op_params) + 11;
- tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ if (tensor->op == GGML_OP_ROPE) {
+ tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ } else {
+ tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ }
  } else {
- tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ if (tensor->op == GGML_OP_ROPE) {
+ tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ } else {
+ tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ }
  }
  } else if (tensor->op == GGML_OP_UNARY) {
  switch (ggml_get_unary_op(tensor)) {
@@ -9018,6 +9119,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  case GGML_UNARY_OP_TANH:
  tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
  break;
+ case GGML_UNARY_OP_SIGMOID:
+ tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+ break;
  default:
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
  GGML_ABORT("fatal error");
@@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
  string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);

  for (const auto& tname : type_names) {
+ std::string load_vec_quant = "2";
+ if ((tname == "q4_0") || (tname == "q4_1"))
+ load_vec_quant = "8";
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+ load_vec_quant = "4";
+
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
- std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
  // For aligned matmul loads
- std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;

  // don't generate f32 variants for coopmat2
  if (!coopmat2) {
@@ -396,7 +402,7 @@ void process_shaders() {
  for (const auto& tname : type_names) {
  // mul mat vec
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
- std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
+ std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";

  string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
  string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
@@ -427,6 +433,7 @@ void process_shaders() {
  string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
  string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

  string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
@@ -477,14 +484,17 @@ void process_shaders() {
  string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

  string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

  string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
  string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+ string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

  string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
  string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
@@ -240,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi


  void * ggml_aligned_malloc(size_t size) {
+ #if defined(__s390x__)
+ const int alignment = 256;
+ #else
  const int alignment = 64;
+ #endif

  #if defined(_MSC_VER) || defined(__MINGW32__)
  return _aligned_malloc(size, alignment);
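Note: the larger s390x value matches z/Architecture's 256-byte cache lines; other targets keep 64. A hedged sketch of the same policy in plain C11 (GGML_ALLOC_ALIGN is a made-up name here, and this is not the function's actual _aligned_malloc/posix_memalign branches):

    #if defined(__s390x__)
    #define GGML_ALLOC_ALIGN 256
    #else
    #define GGML_ALLOC_ALIGN 64
    #endif
    // aligned_alloc requires size to be a multiple of the alignment, so round up:
    size_t padded = (size + GGML_ALLOC_ALIGN - 1) / GGML_ALLOC_ALIGN * GGML_ALLOC_ALIGN;
    void * buf = aligned_alloc(GGML_ALLOC_ALIGN, padded);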
@@ -561,9 +565,9 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
  #endif

  }
- static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
- static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
- static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+ static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
+ static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
+ static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);

  static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
  [GGML_TYPE_I8] = {
@@ -2328,6 +2332,7 @@ struct ggml_tensor * ggml_concat(
  struct ggml_tensor * b,
  int dim) {
  GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+ GGML_ASSERT(a->type == b->type);

  int64_t ne[GGML_MAX_DIMS];
  for (int d = 0; d < GGML_MAX_DIMS; ++d) {
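Note: with the added assert, ggml_concat now aborts on mixed operand types instead of concatenating buffers with different element layouts; callers must convert one operand first. A hedged usage sketch (ggml_cast is ggml's existing conversion op; the tensor names are illustrative):

    // Concatenating an F32 tensor with an F16 tensor along dim 2 now needs a cast:
    struct ggml_tensor * b_f32 = ggml_cast(ctx, b_f16, GGML_TYPE_F32);
    struct ggml_tensor * out   = ggml_concat(ctx, a_f32, b_f32, 2);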