whispercpp 1.3.2 → 1.3.3

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,6 +1,6 @@
1
1
  #include "ggml-vulkan.h"
2
2
  #include <vulkan/vulkan_core.h>
3
- #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
3
+ #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
4
4
  #include <chrono>
5
5
  #include "ggml-cpu.h"
6
6
  #endif
@@ -78,7 +78,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
78
78
  #define VK_VENDOR_ID_INTEL 0x8086
79
79
  #define VK_VENDOR_ID_NVIDIA 0x10de
80
80
 
81
- #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32
81
+ #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
82
82
 
83
83
  #define GGML_VK_MAX_NODES 8192
84
84
 
@@ -102,25 +102,11 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
102
102
 
103
103
  struct ggml_backend_vk_context;
104
104
 
105
- struct vk_queue {
106
- uint32_t queue_family_index;
107
- vk::Queue queue;
108
- vk::CommandPool pool;
109
- uint32_t cmd_buffer_idx;
110
- std::vector<vk::CommandBuffer> cmd_buffers;
111
-
112
- vk::PipelineStageFlags stage_flags;
113
-
114
- bool transfer_only;
115
- };
105
+ #define MAX_PARAMETER_COUNT 8
116
106
 
117
107
  struct vk_pipeline_struct {
118
108
  std::string name;
119
109
  vk::ShaderModule shader_module;
120
- vk::DescriptorSetLayout dsl;
121
- std::vector<vk::DescriptorPool> descriptor_pools;
122
- std::vector<vk::DescriptorSet> descriptor_sets;
123
- uint32_t descriptor_set_idx;
124
110
  vk::PipelineLayout layout;
125
111
  vk::Pipeline pipeline;
126
112
  uint32_t push_constant_size;
@@ -167,6 +153,45 @@ struct ggml_backend_vk_buffer_type_context {
167
153
  vk_device device;
168
154
  };
169
155
 
156
+ struct vk_queue;
157
+
158
+ // Stores command pool/buffers. There's an instance of this
159
+ // for each (context,queue) pair and for each (device,queue) pair.
160
+ struct vk_command_pool {
161
+ void init(vk_device& device, vk_queue *q_);
162
+ void destroy(vk::Device& device);
163
+
164
+ vk::CommandPool pool;
165
+ uint32_t cmd_buffer_idx;
166
+ std::vector<vk::CommandBuffer> cmd_buffers;
167
+
168
+ vk_queue *q;
169
+ };
170
+
171
+ // Prevent simultaneous submissions to the same queue.
172
+ // This could be per vk_queue if we stopped having two vk_queue structures
173
+ // sharing the same vk::Queue.
174
+ static std::mutex queue_mutex;
175
+
176
+ struct vk_queue {
177
+ uint32_t queue_family_index;
178
+ vk::Queue queue;
179
+
180
+ vk_command_pool cmd_pool;
181
+
182
+ vk::PipelineStageFlags stage_flags;
183
+
184
+ bool transfer_only;
185
+
186
+ // copy everything except the cmd_pool
187
+ void copyFrom(vk_queue &other) {
188
+ queue_family_index = other.queue_family_index;
189
+ queue = other.queue;
190
+ stage_flags = other.stage_flags;
191
+ transfer_only = other.transfer_only;
192
+ }
193
+ };
194
+
170
195
  static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
171
196
  static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
172
197
  static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
@@ -184,9 +209,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
184
209
  #ifdef GGML_VULKAN_MEMORY_DEBUG
185
210
  class vk_memory_logger;
186
211
  #endif
187
- #ifdef GGML_VULKAN_PERF
188
212
  class vk_perf_logger;
189
- #endif
190
213
  static void ggml_vk_destroy_buffer(vk_buffer& buf);
191
214
 
192
215
  static constexpr uint32_t mul_mat_vec_max_cols = 8;
@@ -198,6 +221,7 @@ enum vk_device_architecture {
198
221
  AMD_RDNA1,
199
222
  AMD_RDNA2,
200
223
  AMD_RDNA3,
224
+ INTEL_XE2,
201
225
  };
202
226
 
203
227
  static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -248,12 +272,40 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
248
272
  }
249
273
  return vk_device_architecture::AMD_RDNA2;
250
274
  }
275
+ } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
276
+ const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
277
+
278
+ bool subgroup_size_control = false;
279
+
280
+ for (const auto& properties : ext_props) {
281
+ if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
282
+ subgroup_size_control = true;
283
+ }
284
+ }
285
+
286
+ if (!subgroup_size_control) {
287
+ return vk_device_architecture::OTHER;
288
+ }
289
+
290
+ vk::PhysicalDeviceProperties2 props2;
291
+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
292
+
293
+ props2.pNext = &subgroup_size_control_props;
294
+ device.getProperties2(&props2);
295
+
296
+ if (subgroup_size_control_props.minSubgroupSize == 16) {
297
+ // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
298
+ // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
299
+ // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
300
+ // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
301
+ return vk_device_architecture::INTEL_XE2;
302
+ }
251
303
  }
252
304
  return vk_device_architecture::OTHER;
253
305
  }
254
306
 
255
307
  struct vk_device_struct {
256
- std::mutex mutex;
308
+ std::recursive_mutex mutex;
257
309
 
258
310
  vk::PhysicalDevice physical_device;
259
311
  vk::PhysicalDeviceProperties properties;
@@ -314,6 +366,8 @@ struct vk_device_struct {
314
366
  // set to true to indicate that some shaders need to be compiled after the dryrun
315
367
  bool need_compiles {};
316
368
 
369
+ vk::DescriptorSetLayout dsl;
370
+
317
371
  vk_matmul_pipeline pipeline_matmul_f32 {};
318
372
  vk_matmul_pipeline pipeline_matmul_f32_f16 {};
319
373
  vk_matmul_pipeline pipeline_matmul_bf16 {};
@@ -371,6 +425,7 @@ struct vk_device_struct {
371
425
  vk_pipeline pipeline_norm_f32;
372
426
  vk_pipeline pipeline_group_norm_f32;
373
427
  vk_pipeline pipeline_rms_norm_f32;
428
+ vk_pipeline pipeline_rms_norm_mul_f32;
374
429
  vk_pipeline pipeline_rms_norm_back_f32;
375
430
  vk_pipeline pipeline_l2_norm_f32;
376
431
 
@@ -382,6 +437,10 @@ struct vk_device_struct {
382
437
  vk_pipeline pipeline_tanh[2];
383
438
  vk_pipeline pipeline_sigmoid[2];
384
439
 
440
+ vk_pipeline pipeline_geglu[2];
441
+ vk_pipeline pipeline_reglu[2];
442
+ vk_pipeline pipeline_swiglu[2];
443
+
385
444
  vk_pipeline pipeline_leaky_relu_f32;
386
445
  vk_pipeline pipeline_silu_back_f32;
387
446
  vk_pipeline pipeline_diag_mask_inf_f32;
@@ -398,6 +457,7 @@ struct vk_device_struct {
398
457
  vk_pipeline pipeline_count_equal_i32;
399
458
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
400
459
  vk_pipeline pipeline_timestep_embedding_f32;
460
+ vk_pipeline pipeline_conv_transpose_1d_f32;
401
461
  vk_pipeline pipeline_pool2d_f32;
402
462
  vk_pipeline pipeline_rwkv_wkv6_f32;
403
463
  vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -430,7 +490,6 @@ struct vk_device_struct {
430
490
  vk_pipeline pipeline_flash_attn_split_k_reduce;
431
491
 
432
492
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
433
- std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
434
493
 
435
494
  std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
436
495
 
@@ -442,9 +501,11 @@ struct vk_device_struct {
442
501
  #ifdef GGML_VULKAN_MEMORY_DEBUG
443
502
  std::unique_ptr<vk_memory_logger> memory_logger;
444
503
  #endif
445
- #ifdef GGML_VULKAN_PERF
504
+
505
+ // for GGML_VK_PERF_LOGGER
446
506
  std::unique_ptr<vk_perf_logger> perf_logger;
447
- #endif
507
+ vk::QueryPool query_pool;
508
+ int32_t num_queries;
448
509
 
449
510
  ~vk_device_struct() {
450
511
  VK_LOG_DEBUG("destroy device " << name);
@@ -453,10 +514,8 @@ struct vk_device_struct {
453
514
 
454
515
  ggml_vk_destroy_buffer(sync_staging);
455
516
 
456
- device.destroyCommandPool(compute_queue.pool);
457
- if (!single_queue) {
458
- device.destroyCommandPool(transfer_queue.pool);
459
- }
517
+ compute_queue.cmd_pool.destroy(device);
518
+ transfer_queue.cmd_pool.destroy(device);
460
519
 
461
520
  for (auto& pipeline : pipelines) {
462
521
  if (pipeline.second.expired()) {
@@ -468,10 +527,26 @@ struct vk_device_struct {
468
527
  }
469
528
  pipelines.clear();
470
529
 
530
+ device.destroyDescriptorSetLayout(dsl);
531
+
471
532
  device.destroy();
472
533
  }
473
534
  };
474
535
 
536
+ void vk_command_pool::init(vk_device& device, vk_queue *q_) {
537
+ cmd_buffer_idx = 0;
538
+ q = q_;
539
+
540
+ vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
541
+ pool = device->device.createCommandPool(command_pool_create_info);
542
+ }
543
+
544
+ void vk_command_pool::destroy(vk::Device& device) {
545
+ device.destroyCommandPool(pool);
546
+ pool = nullptr;
547
+ cmd_buffers.clear();
548
+ }
549
+
475
550
  struct vk_buffer_struct {
476
551
  vk::Buffer buffer = VK_NULL_HANDLE;
477
552
  vk::DeviceMemory device_memory = VK_NULL_HANDLE;
@@ -590,6 +665,13 @@ struct vk_op_push_constants {
590
665
  float param2;
591
666
  };
592
667
 
668
+ struct vk_op_glu_push_constants {
669
+ uint32_t N;
670
+ uint32_t ne00;
671
+ uint32_t ne20;
672
+ uint32_t mode; // 0: default, 1: swapped, 2: split
673
+ };
674
+
593
675
  struct vk_op_unary_push_constants {
594
676
  uint32_t ne;
595
677
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -706,6 +788,21 @@ struct vk_op_timestep_embedding_push_constants {
706
788
  uint32_t max_period;
707
789
  };
708
790
 
791
+ struct vk_op_conv_transpose_1d_push_constants {
792
+ uint32_t Cout;
793
+ uint32_t Cin;
794
+ uint32_t K;
795
+ uint32_t L;
796
+ uint32_t KL;
797
+
798
+ uint32_t nb01;
799
+ uint32_t nb02;
800
+ uint32_t nb11;
801
+ uint32_t nb1;
802
+
803
+ int32_t s0;
804
+ };
805
+
709
806
  struct vk_op_pool2d_push_constants {
710
807
  uint32_t IW; uint32_t IH;
711
808
  uint32_t OW; uint32_t OH;
@@ -774,7 +871,7 @@ struct vk_context_struct {
774
871
  std::vector<vk_staging_memcpy> in_memcpys;
775
872
  std::vector<vk_staging_memcpy> out_memcpys;
776
873
 
777
- vk_queue * q;
874
+ vk_command_pool * p {};
778
875
  };
779
876
  typedef std::shared_ptr<vk_context_struct> vk_context;
780
877
  typedef std::weak_ptr<vk_context_struct> vk_context_ref;
@@ -828,8 +925,6 @@ private:
828
925
  #define VK_LOG_MEMORY(msg) ((void) 0)
829
926
  #endif // GGML_VULKAN_MEMORY_DEBUG
830
927
 
831
- #if defined(GGML_VULKAN_PERF)
832
-
833
928
  class vk_perf_logger {
834
929
  public:
835
930
  void print_timings() {
@@ -839,7 +934,7 @@ public:
839
934
  for (const auto& time : t.second) {
840
935
  total += time;
841
936
  }
842
- std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl;
937
+ std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
843
938
  }
844
939
 
845
940
  timings.clear();
@@ -868,7 +963,6 @@ public:
868
963
  private:
869
964
  std::map<std::string, std::vector<uint64_t>> timings;
870
965
  };
871
- #endif // GGML_VULKAN_PERF
872
966
 
873
967
  struct ggml_backend_vk_context {
874
968
  std::string name;
@@ -888,6 +982,18 @@ struct ggml_backend_vk_context {
888
982
  vk_context_ref transfer_ctx;
889
983
 
890
984
  std::vector<vk_context_ref> tensor_ctxs;
985
+
986
+ std::vector<vk::DescriptorPool> descriptor_pools;
987
+ std::vector<vk::DescriptorSet> descriptor_sets;
988
+ uint32_t descriptor_set_idx {};
989
+ uint32_t pipeline_descriptor_set_requirements {};
990
+
991
+ vk_command_pool compute_cmd_pool;
992
+ vk_command_pool transfer_cmd_pool;
993
+
994
+ // number of additional consecutive nodes that are being fused with the
995
+ // node currently being processed
996
+ uint32_t num_additional_fused_ops {};
891
997
  };
892
998
 
893
999
  static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -951,6 +1057,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
951
1057
  struct vk_instance_t {
952
1058
  vk::Instance instance;
953
1059
 
1060
+ bool debug_utils_support = false; // VK_EXT_debug_utils enabled
1061
+ PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
1062
+ PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
1063
+ PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
1064
+ PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
1065
+ PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
1066
+ PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
1067
+
954
1068
  std::vector<size_t> device_indices;
955
1069
  vk_device devices[GGML_VK_MAX_DEVICES];
956
1070
  };
@@ -958,6 +1072,8 @@ struct vk_instance_t {
958
1072
  static bool vk_instance_initialized = false;
959
1073
  static vk_instance_t vk_instance;
960
1074
 
1075
+ static bool vk_perf_logger_enabled = false;
1076
+
961
1077
  #ifdef GGML_VULKAN_CHECK_RESULTS
962
1078
  static size_t vk_skip_checks;
963
1079
  static size_t vk_output_tensor;
@@ -1016,39 +1132,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1016
1132
  ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
1017
1133
  disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
1018
1134
  GGML_ASSERT(parameter_count > 0);
1135
+ GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
1019
1136
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
1020
1137
 
1021
1138
  vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
1022
1139
  pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
1023
1140
 
1024
- std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
1025
- std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
1026
- for (uint32_t i = 0; i < parameter_count; i++) {
1027
- dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
1028
- dsl_binding_flags.push_back({});
1029
- }
1030
-
1031
- vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
1032
-
1033
1141
  vk::PushConstantRange pcr(
1034
1142
  vk::ShaderStageFlagBits::eCompute,
1035
1143
  0,
1036
1144
  pipeline->push_constant_size
1037
1145
  );
1038
1146
 
1039
- vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
1040
- {},
1041
- dsl_binding);
1042
- descriptor_set_layout_create_info.setPNext(&dslbfci);
1043
- pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
1044
-
1045
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
1046
- vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
1047
- pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
1048
-
1049
- pipeline->descriptor_set_idx = 0;
1050
-
1051
- vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
1147
+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
1052
1148
  pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
1053
1149
 
1054
1150
  std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
@@ -1108,8 +1204,16 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1108
1204
  }
1109
1205
  pipeline->compiled = true;
1110
1206
 
1207
+ if (vk_instance.debug_utils_support) {
1208
+ vk::DebugUtilsObjectNameInfoEXT duoni;
1209
+ duoni.objectType = vk::ObjectType::ePipeline;
1210
+ duoni.pObjectName = pipeline->name.c_str();
1211
+ duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
1212
+ vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
1213
+ }
1214
+
1111
1215
  {
1112
- std::lock_guard<std::mutex> guard(device->mutex);
1216
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1113
1217
  device->pipelines.insert({ pipeline->name, pipeline });
1114
1218
  }
1115
1219
 
@@ -1123,15 +1227,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1123
1227
 
1124
1228
  static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
1125
1229
  VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
1126
- for (auto& pool : pipeline->descriptor_pools) {
1127
- device.destroyDescriptorPool(pool);
1128
- }
1129
- pipeline->descriptor_pools.clear();
1130
- pipeline->descriptor_sets.clear();
1131
- pipeline->descriptor_set_idx = 0;
1132
-
1133
- device.destroyDescriptorSetLayout(pipeline->dsl);
1134
-
1135
1230
  device.destroyPipelineLayout(pipeline->layout);
1136
1231
 
1137
1232
  device.destroyShaderModule(pipeline->shader_module);
@@ -1139,97 +1234,77 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
1139
1234
  device.destroyPipeline(pipeline->pipeline);
1140
1235
  }
1141
1236
 
1142
- static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
1237
+ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
1143
1238
  VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
1144
- device->pipeline_descriptor_set_requirements[pipeline->name] += n;
1239
+ ctx->pipeline_descriptor_set_requirements += n;
1145
1240
  if (!pipeline->compiled) {
1146
1241
  pipeline->needed = true;
1147
- device->need_compiles = true;
1242
+ ctx->device->need_compiles = true;
1148
1243
  }
1149
1244
  }
1150
1245
 
1151
- static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
1152
- std::lock_guard<std::mutex> guard(device->mutex);
1153
-
1154
- for (auto& pair : device->pipeline_descriptor_set_requirements) {
1155
- vk_pipeline pipeline = device->pipelines.at(pair.first).lock();
1156
- const uint64_t n = pair.second;
1246
+ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
1157
1247
 
1158
- VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
1159
-
1160
- if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
1161
- // Enough descriptors are available
1162
- continue;
1163
- }
1248
+ if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
1249
+ // Enough descriptors are available
1250
+ return;
1251
+ }
1164
1252
 
1165
- uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
1166
- uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1167
- uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1253
+ vk_device& device = ctx->device;
1168
1254
 
1169
- while (to_alloc > 0) {
1170
- const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
1171
- to_alloc -= alloc_count;
1172
- pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1255
+ uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size();
1256
+ uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1257
+ uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1173
1258
 
1174
- if (pool_idx >= pipeline->descriptor_pools.size()) {
1175
- vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
1176
- vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
1177
- pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
1178
- }
1259
+ while (to_alloc > 0) {
1260
+ const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
1261
+ to_alloc -= alloc_count;
1262
+ pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
1179
1263
 
1180
- std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
1181
- for (uint32_t i = 0; i < alloc_count; i++) {
1182
- layouts[i] = pipeline->dsl;
1183
- }
1184
- vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data());
1185
- std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
1186
- pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
1264
+ if (pool_idx >= ctx->descriptor_pools.size()) {
1265
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
1266
+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
1267
+ ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
1268
+ }
1187
1269
 
1188
- pool_idx++;
1270
+ std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
1271
+ for (uint32_t i = 0; i < alloc_count; i++) {
1272
+ layouts[i] = device->dsl;
1189
1273
  }
1190
- }
1191
- }
1274
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
1275
+ std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
1276
+ ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
1192
1277
 
1193
- static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
1194
- VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")");
1195
- pipeline->descriptor_set_idx = 0;
1278
+ pool_idx++;
1279
+ }
1196
1280
  }
1197
1281
 
1198
- static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
1282
+ static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
1199
1283
  VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
1200
- std::lock_guard<std::mutex> guard(device->mutex);
1201
1284
 
1202
- if (q.cmd_buffers.size() > q.cmd_buffer_idx) {
1285
+ if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
1203
1286
  // Reuse command buffer
1204
- return q.cmd_buffers[q.cmd_buffer_idx++];
1287
+ return p.cmd_buffers[p.cmd_buffer_idx++];
1205
1288
  }
1206
1289
 
1207
1290
  vk::CommandBufferAllocateInfo command_buffer_alloc_info(
1208
- q.pool,
1291
+ p.pool,
1209
1292
  vk::CommandBufferLevel::ePrimary,
1210
1293
  1);
1211
1294
  const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
1212
1295
  auto buf = cmd_buffers.front();
1213
1296
 
1214
- q.cmd_buffers.push_back(buf);
1215
- q.cmd_buffer_idx++;
1297
+ p.cmd_buffers.push_back(buf);
1298
+ p.cmd_buffer_idx++;
1216
1299
 
1217
1300
  return buf;
1218
1301
  }
1219
1302
 
1220
- static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
1221
- VK_LOG_DEBUG("ggml_vk_create_submission()");
1222
- vk_submission s;
1223
- s.buffer = ggml_vk_create_cmd_buffer(device, q);
1224
- s.wait_semaphores = std::move(wait_semaphores);
1225
- s.signal_semaphores = std::move(signal_semaphores);
1226
- return s;
1227
- }
1228
-
1229
1303
  static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
1230
1304
  if (ctx->seqs.empty()) {
1231
1305
  if (fence) {
1232
- ctx->q->queue.submit({}, fence);
1306
+ std::lock_guard<std::mutex> guard(queue_mutex);
1307
+ ctx->p->q->queue.submit({}, fence);
1233
1308
  }
1234
1309
  return;
1235
1310
  }
@@ -1268,7 +1343,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
1268
1343
  tl_signal_vals.push_back({});
1269
1344
  tl_signal_semaphores.push_back({});
1270
1345
  for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
1271
- stage_flags[idx].push_back(ctx->q->stage_flags);
1346
+ stage_flags[idx].push_back(ctx->p->q->stage_flags);
1272
1347
  tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
1273
1348
  tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
1274
1349
  }
@@ -1298,7 +1373,8 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
1298
1373
  }
1299
1374
  }
1300
1375
 
1301
- ctx->q->queue.submit(submit_infos, fence);
1376
+ std::lock_guard<std::mutex> guard(queue_mutex);
1377
+ ctx->p->q->queue.submit(submit_infos, fence);
1302
1378
 
1303
1379
  ctx->seqs.clear();
1304
1380
  }
@@ -1351,33 +1427,30 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
1351
1427
 
1352
1428
  static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
1353
1429
  VK_LOG_DEBUG("ggml_vk_create_queue()");
1354
- std::lock_guard<std::mutex> guard(device->mutex);
1430
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1355
1431
 
1356
1432
  q.queue_family_index = queue_family_index;
1357
1433
  q.transfer_only = transfer_only;
1358
1434
 
1359
- vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
1360
- q.pool = device->device.createCommandPool(command_pool_create_info_compute);
1361
-
1362
- q.cmd_buffer_idx = 0;
1435
+ q.cmd_pool.init(device, &q);
1363
1436
 
1364
1437
  q.queue = device->device.getQueue(queue_family_index, queue_index);
1365
1438
 
1366
1439
  q.stage_flags = stage_flags;
1367
1440
  }
1368
1441
 
1369
- static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
1442
+ static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
1370
1443
  vk_context result = std::make_shared<vk_context_struct>();
1371
1444
  VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
1372
1445
  ctx->gc.contexts.emplace_back(result);
1373
- result->q = &q;
1446
+ result->p = &p;
1374
1447
  return result;
1375
1448
  }
1376
1449
 
1377
- static vk_context ggml_vk_create_temporary_context(vk_queue& q) {
1450
+ static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
1378
1451
  vk_context result = std::make_shared<vk_context_struct>();
1379
1452
  VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
1380
- result->q = &q;
1453
+ result->p = &p;
1381
1454
  return result;
1382
1455
  }
1383
1456
 
@@ -1410,15 +1483,29 @@ static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
1410
1483
  return ctx->gc.events[ctx->event_idx++];
1411
1484
  }
1412
1485
 
1413
- static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
1414
- VK_LOG_DEBUG("ggml_vk_queue_cleanup()");
1415
- std::lock_guard<std::mutex> guard(device->mutex);
1486
+ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
1487
+ VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
1416
1488
 
1417
1489
  // Requires command buffers to be done
1418
- device->device.resetCommandPool(q.pool);
1419
- q.cmd_buffer_idx = 0;
1490
+ device->device.resetCommandPool(p.pool);
1491
+ p.cmd_buffer_idx = 0;
1420
1492
  }
1421
1493
 
1494
+ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
1495
+ VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
1496
+
1497
+ // Arbitrary frequency to cleanup/reuse command buffers
1498
+ static constexpr uint32_t cleanup_frequency = 10;
1499
+
1500
+ if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
1501
+ ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
1502
+ }
1503
+ if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
1504
+ ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
1505
+ }
1506
+ }
1507
+
1508
+
1422
1509
  static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
1423
1510
  for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
1424
1511
  vk::MemoryType memory_type = mem_props->memoryTypes[i];
@@ -1437,8 +1524,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
1437
1524
  throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
1438
1525
  }
1439
1526
 
1440
- std::lock_guard<std::mutex> guard(device->mutex);
1441
-
1442
1527
  vk_buffer buf = std::make_shared<vk_buffer_struct>();
1443
1528
 
1444
1529
  if (size == 0) {
@@ -1567,11 +1652,11 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
1567
1652
  static void ggml_vk_sync_buffers(vk_context& ctx) {
1568
1653
  VK_LOG_DEBUG("ggml_vk_sync_buffers()");
1569
1654
 
1570
- const bool transfer_queue = ctx->q->transfer_only;
1655
+ const bool transfer_queue = ctx->p->q->transfer_only;
1571
1656
 
1572
1657
  ctx->s->buffer.pipelineBarrier(
1573
- ctx->q->stage_flags,
1574
- ctx->q->stage_flags,
1658
+ ctx->p->q->stage_flags,
1659
+ ctx->p->q->stage_flags,
1575
1660
  {},
1576
1661
  { {
1577
1662
  { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
@@ -1590,8 +1675,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
1590
1675
 
1591
1676
  ctx->s->buffer.waitEvents(
1592
1677
  events,
1593
- ctx->q->stage_flags,
1594
- ctx->q->stage_flags,
1678
+ ctx->p->q->stage_flags,
1679
+ ctx->p->q->stage_flags,
1595
1680
  {},
1596
1681
  {},
1597
1682
  {}
@@ -1653,7 +1738,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
1653
1738
  return {64, 32};
1654
1739
  }
1655
1740
  return {64, 64};
1656
- };
1741
+ }
1657
1742
 
1658
1743
  static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
1659
1744
 
@@ -2586,7 +2671,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2586
2671
 
2587
2672
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2588
2673
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2589
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2674
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2675
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
2590
2676
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2591
2677
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2592
2678
 
@@ -2682,6 +2768,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
2682
2768
  CREATE_UNARY(sigmoid)
2683
2769
  #undef CREATE_UNARY
2684
2770
 
2771
+ #define CREATE_GLU(name) \
2772
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2773
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
2774
+
2775
+ CREATE_GLU(geglu)
2776
+ CREATE_GLU(reglu)
2777
+ CREATE_GLU(swiglu)
2778
+ #undef CREATE_GLU
2779
+
2685
2780
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2686
2781
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2687
2782
 
@@ -2727,6 +2822,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2727
2822
 
2728
2823
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
2729
2824
 
2825
+ ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
2826
+
2730
2827
  ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
2731
2828
 
2732
2829
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -2757,9 +2854,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2757
2854
  #ifdef GGML_VULKAN_MEMORY_DEBUG
2758
2855
  device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
2759
2856
  #endif
2760
- #ifdef GGML_VULKAN_PERF
2761
- device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
2762
- #endif
2857
+ if (vk_perf_logger_enabled) {
2858
+ device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
2859
+ }
2763
2860
 
2764
2861
  size_t dev_num = vk_instance.device_indices[idx];
2765
2862
 
@@ -3323,6 +3420,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
3323
3420
  }
3324
3421
  }
3325
3422
 
3423
+
3424
+ std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
3425
+ std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
3426
+ for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
3427
+ dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
3428
+ dsl_binding_flags.push_back({});
3429
+ }
3430
+
3431
+ vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
3432
+
3433
+ vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
3434
+ {},
3435
+ dsl_binding);
3436
+ descriptor_set_layout_create_info.setPNext(&dslbfci);
3437
+ device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
3438
+
3326
3439
  ggml_vk_load_shaders(device);
3327
3440
 
3328
3441
  if (!device->single_queue) {
@@ -3330,7 +3443,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
3330
3443
  ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
3331
3444
  } else {
3332
3445
  // TODO: Use pointer or reference to avoid copy
3333
- device->transfer_queue = device->compute_queue;
3446
+ device->transfer_queue.copyFrom(device->compute_queue);
3447
+ device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
3334
3448
  }
3335
3449
 
3336
3450
  device->buffer_type = {
@@ -3489,6 +3603,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3489
3603
  static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
3490
3604
  static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
3491
3605
 
3606
+ static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
3607
+
3492
3608
  static void ggml_vk_instance_init() {
3493
3609
  if (vk_instance_initialized) {
3494
3610
  return;
@@ -3509,7 +3625,7 @@ static void ggml_vk_instance_init() {
3509
3625
  #ifdef __APPLE__
3510
3626
  const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
3511
3627
  #endif
3512
-
3628
+ const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
3513
3629
  std::vector<const char*> layers;
3514
3630
 
3515
3631
  if (validation_ext) {
@@ -3524,6 +3640,9 @@ static void ggml_vk_instance_init() {
3524
3640
  extensions.push_back("VK_KHR_portability_enumeration");
3525
3641
  }
3526
3642
  #endif
3643
+ if (debug_utils_ext) {
3644
+ extensions.push_back("VK_EXT_debug_utils");
3645
+ }
3527
3646
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
3528
3647
  #ifdef __APPLE__
3529
3648
  if (portability_enumeration_ext) {
@@ -3547,11 +3666,25 @@ static void ggml_vk_instance_init() {
3547
3666
  vk_instance.instance = vk::createInstance(instance_create_info);
3548
3667
  vk_instance_initialized = true;
3549
3668
 
3669
+ if (debug_utils_ext) {
3670
+ vk_instance.debug_utils_support = true;
3671
+ vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
3672
+ vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
3673
+ vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
3674
+ vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
3675
+ vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
3676
+ vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
3677
+
3678
+ }
3679
+
3550
3680
  size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
3681
+ vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
3551
3682
 
3552
3683
  // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
3553
3684
  char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
3554
3685
  if (devices_env != nullptr) {
3686
+ size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
3687
+
3555
3688
  std::string devices(devices_env);
3556
3689
  std::replace(devices.begin(), devices.end(), ',', ' ');
3557
3690
 
@@ -3567,9 +3700,9 @@ static void ggml_vk_instance_init() {
3567
3700
  } else {
3568
3701
  std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
3569
3702
 
3570
- // Make sure at least one device exists
3703
+ // If no vulkan devices are found, return early
3571
3704
  if (devices.empty()) {
3572
- std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
3705
+ GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
3573
3706
  return;
3574
3707
  }
3575
3708
 
@@ -3652,9 +3785,20 @@ static void ggml_vk_instance_init() {
3652
3785
  }
3653
3786
  }
3654
3787
 
3655
- // If no dedicated GPUs found, fall back to GPU 0
3788
+ // If no dedicated GPUs found, fall back to the first non-CPU device.
3789
+ // If only CPU devices are available, return without devices.
3790
+ if (vk_instance.device_indices.empty()) {
3791
+ for (size_t i = 0; i < devices.size(); i++) {
3792
+ if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {
3793
+ vk_instance.device_indices.push_back(i);
3794
+ break;
3795
+ }
3796
+ }
3797
+ }
3798
+
3656
3799
  if (vk_instance.device_indices.empty()) {
3657
- vk_instance.device_indices.push_back(0);
3800
+ GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
3801
+ return;
3658
3802
  }
3659
3803
  }
3660
3804
  GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
@@ -3683,6 +3827,9 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
3683
3827
  ctx->fence = ctx->device->device.createFence({});
3684
3828
  ctx->almost_ready_fence = ctx->device->device.createFence({});
3685
3829
 
3830
+ ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
3831
+ ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
3832
+
3686
3833
  #ifdef GGML_VULKAN_CHECK_RESULTS
3687
3834
  const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
3688
3835
  vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
@@ -4003,6 +4150,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
4003
4150
  return nullptr;
4004
4151
  }
4005
4152
 
4153
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4006
4154
  device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
4007
4155
 
4008
4156
  return buf->ptr;
@@ -4013,6 +4161,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4013
4161
  return;
4014
4162
  }
4015
4163
  VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
4164
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4165
+
4016
4166
  vk_buffer buf;
4017
4167
  size_t index;
4018
4168
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4035,6 +4185,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4035
4185
  }
4036
4186
 
4037
4187
  static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
4188
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4038
4189
  buf = nullptr;
4039
4190
  buf_offset = 0;
4040
4191
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
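These hunks put every access to the device's pinned-memory list, including the read path in ggml_vk_host_get, behind the same device mutex. A minimal sketch of the pattern, using an illustrative registry type of our own rather than the backend's vk_device.

```cpp
#include <cstddef>
#include <mutex>
#include <tuple>
#include <vector>

// Illustrative pinned-allocation registry guarded the same way as
// device->pinned_memory: every reader and writer takes the same lock.
struct pinned_registry {
    std::recursive_mutex mutex;
    std::vector<std::tuple<void *, size_t>> entries;  // (ptr, size)

    void add(void * ptr, size_t size) {
        std::lock_guard<std::recursive_mutex> guard(mutex);
        entries.push_back(std::make_tuple(ptr, size));
    }

    // Find the registered allocation containing `ptr` and report the offset,
    // mirroring what ggml_vk_host_get does under the new guard.
    bool find(const void * ptr, size_t & offset) {
        std::lock_guard<std::recursive_mutex> guard(mutex);
        for (const auto & e : entries) {
            const char * base = static_cast<const char *>(std::get<0>(e));
            const char * p    = static_cast<const char *>(ptr);
            if (p >= base && p < base + std::get<1>(e)) {
                offset = (size_t) (p - base);
                return true;
            }
        }
        return false;
    }
};
```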
@@ -4048,9 +4199,9 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf
4048
4199
  }
4049
4200
  }
4050
4201
 
4051
- static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) {
4202
+ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
4052
4203
  vk_submission s;
4053
- s.buffer = ggml_vk_create_cmd_buffer(device, q);
4204
+ s.buffer = ggml_vk_create_cmd_buffer(device, p);
4054
4205
  if (one_time) {
4055
4206
  s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
4056
4207
  } else {
@@ -4060,7 +4211,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
4060
4211
  return s;
4061
4212
  }
4062
4213
 
4063
- static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
4214
+ template <typename T> size_t push_constant_size(const T &t) {
4215
+ static_assert(std::is_class<T>::value, "T must be a struct/class");
4216
+ GGML_UNUSED(t);
4217
+ return sizeof(T);
4218
+ }
4219
+ template <typename T> size_t push_constant_size(const std::vector<T> &t) {
4220
+ GGML_UNUSED(t);
4221
+ return sizeof(T) * t.size();
4222
+ }
4223
+ template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
4224
+ GGML_UNUSED(t);
4225
+ return sizeof(T) * N;
4226
+ }
4227
+
4228
+ template <typename T> const T *push_constant_data(const T &t) {
4229
+ static_assert(std::is_class<T>::value, "T must be a struct/class");
4230
+ return &t;
4231
+ }
4232
+ template <typename T> const T *push_constant_data(const std::vector<T> &t) {
4233
+ return t.data();
4234
+ }
4235
+ template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
4236
+ return t.data();
4237
+ }
4238
+
4239
+ template <typename T>
4240
+ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
4064
4241
  const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
4065
4242
  const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
4066
4243
  const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
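The dispatch path now infers the push-constant size and data pointer from the argument's type, so callers hand over the struct, std::vector, or std::array itself instead of a (size, pointer) pair. A standalone sketch of how that overload set resolves, with simplified names and without the GGML_UNUSED macro.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>
#include <vector>

template <typename T> size_t pc_size(const T &) {
    static_assert(std::is_class<T>::value, "T must be a struct/class");
    return sizeof(T);
}
template <typename T> size_t pc_size(const std::vector<T> & t) { return sizeof(T) * t.size(); }
template <typename T, size_t N> size_t pc_size(const std::array<T, N> &) { return sizeof(T) * N; }

struct my_push_constants { uint32_t m, n, k; };

int main() {
    my_push_constants pc = { 4, 8, 16 };
    std::vector<uint32_t> vec = { 1, 2, 3 };
    std::array<uint32_t, 2> arr = { 7, 9 };
    std::printf("%zu %zu %zu\n", pc_size(pc), pc_size(vec), pc_size(arr));  // prints: 12 12 8
    return 0;
}
```

Because the std::vector and std::array overloads are more specialized than the generic template, overload resolution picks them automatically, which is what lets the call sites below drop their explicit sizeof arguments.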
@@ -4069,14 +4246,14 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
4069
4246
  std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
4070
4247
  }
4071
4248
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
4072
- GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
4073
- GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count);
4249
+ GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
4250
+ GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
4074
4251
 
4075
- vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
4252
+ vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
4076
4253
  vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
4077
4254
  ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
4078
4255
 
4079
- subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
4256
+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
4080
4257
  subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
4081
4258
  subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
4082
4259
  pipeline->layout,
@@ -4109,7 +4286,7 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
4109
4286
  ggml_vk_ctx_end(subctx);
4110
4287
  }
4111
4288
 
4112
- subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
4289
+ subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
4113
4290
  subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
4114
4291
  }
4115
4292
 
@@ -4310,7 +4487,9 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
4310
4487
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
4311
4488
  }
4312
4489
  } else {
4313
- vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
4490
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4491
+
4492
+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4314
4493
  ggml_vk_ctx_begin(dst->device, subctx);
4315
4494
  ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
4316
4495
  ggml_vk_ctx_end(subctx);
@@ -4322,6 +4501,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
4322
4501
  ggml_vk_submit(subctx, dst->device->fence);
4323
4502
  VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
4324
4503
  dst->device->device.resetFences({ dst->device->fence });
4504
+ ggml_vk_queue_command_pools_cleanup(dst->device);
4325
4505
  }
4326
4506
  }
4327
4507
 
@@ -4398,7 +4578,9 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
4398
4578
 
4399
4579
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
4400
4580
  } else {
4401
- vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
4581
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4582
+
4583
+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
4402
4584
  ggml_vk_ctx_begin(src->device, subctx);
4403
4585
  ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
4404
4586
  ggml_vk_ctx_end(subctx);
@@ -4406,6 +4588,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
4406
4588
  ggml_vk_submit(subctx, src->device->fence);
4407
4589
  VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
4408
4590
  src->device->device.resetFences({ src->device->fence });
4591
+ ggml_vk_queue_command_pools_cleanup(src->device);
4409
4592
 
4410
4593
  for (auto& cpy : subctx->out_memcpys) {
4411
4594
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -4425,15 +4608,17 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
4425
4608
 
4426
4609
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
4427
4610
  if (src->device == dst->device) {
4611
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4428
4612
  VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
4429
4613
  // Copy within the device
4430
- vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
4614
+ vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
4431
4615
  ggml_vk_ctx_begin(src->device, subctx);
4432
4616
  ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
4433
4617
  ggml_vk_ctx_end(subctx);
4434
4618
  ggml_vk_submit(subctx, src->device->fence);
4435
4619
  VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
4436
4620
  src->device->device.resetFences({ src->device->fence });
4621
+ ggml_vk_queue_command_pools_cleanup(src->device);
4437
4622
  } else {
4438
4623
  VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
4439
4624
  // Copy device to device
@@ -4458,7 +4643,8 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
4458
4643
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
4459
4644
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
4460
4645
 
4461
- vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
4646
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4647
+ vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4462
4648
  ggml_vk_ctx_begin(dst->device, subctx);
4463
4649
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
4464
4650
  ggml_vk_ctx_end(subctx);
@@ -4466,6 +4652,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
4466
4652
  ggml_vk_submit(subctx, dst->device->fence);
4467
4653
  VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
4468
4654
  dst->device->device.resetFences({ dst->device->fence });
4655
+ ggml_vk_queue_command_pools_cleanup(dst->device);
4469
4656
  }
4470
4657
 
4471
4658
  static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
@@ -4539,7 +4726,7 @@ static void ggml_vk_matmul(
4539
4726
  ggml_vk_sync_buffers(subctx);
4540
4727
  if (split_k == 1) {
4541
4728
  const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
4542
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
4729
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
4543
4730
  return;
4544
4731
  }
4545
4732
 
@@ -4547,10 +4734,10 @@ static void ggml_vk_matmul(
4547
4734
 
4548
4735
  const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
4549
4736
  // Make sure enough workgroups get assigned for split k to work
4550
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
4737
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
4551
4738
  ggml_vk_sync_buffers(subctx);
4552
4739
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
4553
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
4740
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
4554
4741
  }
4555
4742
 
4556
4743
  static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
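The split-k path keeps its two-pass structure: each of the split_k slices writes a partial m*n*batch result into split_k_buffer, and pipeline_matmul_split_k_reduce folds the slices together, now driven by the pc2 array directly. A small CPU model of that reduction, assuming the partials are laid out one full slice after another.

```cpp
#include <cstdint>
#include <vector>

// Reference reduction for split-k matmul partials: `partials` holds split_k
// consecutive slices of ne = m*n*batch elements each; the result is their sum.
static std::vector<float> split_k_reduce(const std::vector<float> & partials,
                                         uint32_t ne, uint32_t split_k) {
    std::vector<float> out(ne, 0.0f);
    for (uint32_t s = 0; s < split_k; s++) {
        for (uint32_t i = 0; i < ne; i++) {
            out[i] += partials[(size_t) s * ne + i];
        }
    }
    return out;
}
```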
@@ -4598,7 +4785,7 @@ static void ggml_vk_matmul_id(
4598
4785
  ggml_vk_sync_buffers(subctx);
4599
4786
  const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
4600
4787
  nei0, nei1, nbi1, ne11, padded_n };
4601
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
4788
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
4602
4789
  }
4603
4790
 
4604
4791
  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -4683,9 +4870,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4683
4870
  // type size must be exactly 2 or 4.
4684
4871
  GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
4685
4872
  if ((ggml_type_size(src->type) % 4) == 0) {
4686
- return ctx->device->pipeline_contig_cpy_f32_f32;
4873
+ if (contig) {
4874
+ return ctx->device->pipeline_contig_cpy_f32_f32;
4875
+ } else {
4876
+ return ctx->device->pipeline_cpy_f32_f32;
4877
+ }
4687
4878
  } else {
4688
- return ctx->device->pipeline_contig_cpy_f16_f16;
4879
+ if (contig) {
4880
+ return ctx->device->pipeline_contig_cpy_f16_f16;
4881
+ } else {
4882
+ return ctx->device->pipeline_cpy_f16_f16;
4883
+ }
4689
4884
  }
4690
4885
  }
4691
4886
 
@@ -4719,7 +4914,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
4719
4914
  };
4720
4915
  init_pushconst_fastdiv(pc);
4721
4916
  ggml_vk_sync_buffers(subctx);
4722
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
4917
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
4723
4918
  }
4724
4919
 
4725
4920
  static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -4738,7 +4933,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
4738
4933
  vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4739
4934
 
4740
4935
  ggml_vk_sync_buffers(subctx);
4741
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
4936
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
4742
4937
  }
4743
4938
 
4744
4939
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -4746,7 +4941,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4746
4941
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
4747
4942
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4748
4943
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4749
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
4944
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
4750
4945
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
4751
4946
 
4752
4947
  const uint64_t ne00 = src0->ne[0];
@@ -4879,18 +5074,18 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4879
5074
  }
4880
5075
 
4881
5076
  // Request descriptor sets
4882
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5077
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
4883
5078
  if (qx_needs_dequant) {
4884
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
5079
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
4885
5080
  }
4886
5081
  if (qy_needs_dequant) {
4887
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
5082
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
4888
5083
  }
4889
5084
  if (quantize_y) {
4890
- ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
5085
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
4891
5086
  }
4892
5087
  if (split_k > 1) {
4893
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
5088
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
4894
5089
  }
4895
5090
  return;
4896
5091
  }
@@ -4938,7 +5133,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4938
5133
  } else if (qx_needs_dequant) {
4939
5134
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
4940
5135
  ggml_vk_sync_buffers(subctx);
4941
- ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
5136
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
4942
5137
  }
4943
5138
  if (y_non_contig) {
4944
5139
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -4974,7 +5169,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
4974
5169
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
4975
5170
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4976
5171
  std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
4977
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5172
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
4978
5173
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
4979
5174
 
4980
5175
  const uint64_t ne00 = src0->ne[0];
@@ -5072,12 +5267,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
5072
5267
 
5073
5268
  // Request descriptor sets
5074
5269
  if (qx_needs_dequant) {
5075
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
5270
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
5076
5271
  }
5077
5272
  if (qy_needs_dequant) {
5078
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
5273
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
5079
5274
  }
5080
- ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
5275
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
5081
5276
  return;
5082
5277
  }
5083
5278
 
@@ -5154,7 +5349,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
5154
5349
  ggml_vk_sync_buffers(subctx);
5155
5350
  ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
5156
5351
  { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
5157
- sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
5352
+ pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
5158
5353
  }
5159
5354
 
5160
5355
  static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5210,7 +5405,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
5210
5405
 
5211
5406
  if (dryrun) {
5212
5407
  // Request descriptor sets
5213
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
5408
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
5214
5409
  return;
5215
5410
  }
5216
5411
 
@@ -5242,7 +5437,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
5242
5437
  }
5243
5438
 
5244
5439
  ggml_vk_sync_buffers(subctx);
5245
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
5440
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
5246
5441
  }
5247
5442
 
5248
5443
  static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5299,7 +5494,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
5299
5494
 
5300
5495
  if (dryrun) {
5301
5496
  // Request descriptor sets
5302
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
5497
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
5303
5498
  return;
5304
5499
  }
5305
5500
 
@@ -5325,7 +5520,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
5325
5520
  const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
5326
5521
  ggml_vk_sync_buffers(subctx);
5327
5522
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
5328
- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
5523
+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
5329
5524
  }
5330
5525
 
5331
5526
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5486,12 +5681,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5486
5681
  }
5487
5682
 
5488
5683
  // Request descriptor sets
5489
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5684
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
5490
5685
  if (qx_needs_dequant) {
5491
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
5686
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
5492
5687
  }
5493
5688
  if (qy_needs_dequant) {
5494
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
5689
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
5495
5690
  }
5496
5691
  return;
5497
5692
  }
@@ -5541,7 +5736,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5541
5736
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
5542
5737
  ggml_vk_sync_buffers(subctx);
5543
5738
  ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
5544
- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
5739
+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
5545
5740
  }
5546
5741
  if (y_non_contig) {
5547
5742
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
@@ -5575,7 +5770,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5575
5770
  std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
5576
5771
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5577
5772
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
5578
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5773
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5579
5774
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5580
5775
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
5581
5776
 
@@ -5680,12 +5875,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5680
5875
 
5681
5876
  // Request descriptor sets
5682
5877
  if (qx_needs_dequant) {
5683
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
5878
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
5684
5879
  }
5685
5880
  if (qy_needs_dequant) {
5686
- ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
5881
+ ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
5687
5882
  }
5688
- ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
5883
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
5689
5884
  return;
5690
5885
  }
5691
5886
 
@@ -5761,7 +5956,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5761
5956
  ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
5762
5957
  { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
5763
5958
  vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
5764
- sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
5959
+ pc, { groups_x, (uint32_t)nei0, groups_z });
5765
5960
  }
5766
5961
 
5767
5962
  static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6005,9 +6200,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6005
6200
 
6006
6201
  if (dryrun) {
6007
6202
  // Request descriptor sets
6008
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6203
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
6009
6204
  if (split_k > 1) {
6010
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
6205
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
6011
6206
  }
6012
6207
  return;
6013
6208
  }
@@ -6111,7 +6306,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6111
6306
  // there's no more than one tile of rows (i.e. workgroups_x would have been
6112
6307
  // one). We reuse workgroups_x to mean the number of splits, so we need to
6113
6308
  // cancel out the divide by wg_denoms[0].
6114
- sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
6309
+ pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
6115
6310
 
6116
6311
  ggml_vk_sync_buffers(subctx);
6117
6312
  const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
@@ -6120,7 +6315,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6120
6315
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
6121
6316
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
6122
6317
  },
6123
- pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
6318
+ pc2, { (uint32_t)ne1, 1, 1 });
6124
6319
  } else {
6125
6320
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
6126
6321
  {
@@ -6130,7 +6325,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6130
6325
  vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
6131
6326
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
6132
6327
  },
6133
- sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
6328
+ pc, { workgroups_x, workgroups_y, workgroups_z });
6134
6329
  }
6135
6330
  }
6136
6331
 
@@ -6261,7 +6456,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6261
6456
  return nullptr;
6262
6457
  case GGML_OP_RMS_NORM:
6263
6458
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6264
- return ctx->device->pipeline_rms_norm_f32;
6459
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
6265
6460
  }
6266
6461
  return nullptr;
6267
6462
  case GGML_OP_RMS_NORM_BACK:
@@ -6298,6 +6493,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6298
6493
  break;
6299
6494
  }
6300
6495
  return nullptr;
6496
+ case GGML_OP_GLU:
6497
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6498
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6499
+ (src0->type != dst->type)) {
6500
+ return nullptr;
6501
+ }
6502
+
6503
+ switch (ggml_get_glu_op(dst)) {
6504
+ case GGML_GLU_OP_GEGLU:
6505
+ return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
6506
+ case GGML_GLU_OP_REGLU:
6507
+ return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6508
+ case GGML_GLU_OP_SWIGLU:
6509
+ return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6510
+ default:
6511
+ break;
6512
+ }
6513
+ return nullptr;
6301
6514
  case GGML_OP_DIAG_MASK_INF:
6302
6515
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6303
6516
  return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6391,6 +6604,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6391
6604
  return ctx->device->pipeline_timestep_embedding_f32;
6392
6605
  }
6393
6606
  return nullptr;
6607
+ case GGML_OP_CONV_TRANSPOSE_1D:
6608
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6609
+ return ctx->device->pipeline_conv_transpose_1d_f32;
6610
+ }
6611
+ return nullptr;
6394
6612
  case GGML_OP_POOL_2D:
6395
6613
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6396
6614
  return ctx->device->pipeline_pool2d_f32;
@@ -6565,7 +6783,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6565
6783
  }
6566
6784
 
6567
6785
  if (dryrun) {
6568
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6786
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
6569
6787
  return;
6570
6788
  }
6571
6789
 
@@ -6725,6 +6943,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6725
6943
  uint32_t half_ceil = (dim + 1) / 2;
6726
6944
  elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
6727
6945
  } break;
6946
+ case GGML_OP_CONV_TRANSPOSE_1D:
6947
+ {
6948
+ elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
6949
+ } break;
6728
6950
  case GGML_OP_POOL_2D:
6729
6951
  {
6730
6952
  const uint32_t N = dst->ne[3];
@@ -6749,6 +6971,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6749
6971
  case GGML_OP_CONCAT:
6750
6972
  case GGML_OP_UPSCALE:
6751
6973
  case GGML_OP_UNARY:
6974
+ case GGML_OP_GLU:
6752
6975
  case GGML_OP_CONV_2D_DW:
6753
6976
  {
6754
6977
  uint32_t ne = ggml_nelements(dst);
@@ -6789,7 +7012,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6789
7012
  }
6790
7013
  }
6791
7014
 
6792
- if (op == GGML_OP_SOFT_MAX) {
7015
+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
6793
7016
  // Empty src1 is possible in soft_max, but the shader needs a buffer
6794
7017
  vk_subbuffer subbuf_y;
6795
7018
  if (use_src1) {
@@ -6799,7 +7022,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6799
7022
  }
6800
7023
 
6801
7024
  ggml_vk_sync_buffers(subctx);
6802
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7025
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6803
7026
  } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
6804
7027
  // Empty src2 is possible in rope, but the shader needs a buffer
6805
7028
  vk_subbuffer subbuf_z;
@@ -6810,26 +7033,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6810
7033
  }
6811
7034
 
6812
7035
  ggml_vk_sync_buffers(subctx);
6813
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7036
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6814
7037
  } else if (op == GGML_OP_IM2COL) {
6815
7038
  // im2col uses only src1 and dst buffers
6816
7039
  ggml_vk_sync_buffers(subctx);
6817
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7040
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6818
7041
  } else if (op == GGML_OP_COUNT_EQUAL) {
6819
7042
  ggml_vk_sync_buffers(subctx);
6820
7043
  // count_equal assumes that destination buffer is initialized with zeroes
6821
7044
  ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
6822
7045
  ggml_vk_sync_buffers(subctx);
6823
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7046
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6824
7047
  } else if (use_src2) {
6825
7048
  ggml_vk_sync_buffers(subctx);
6826
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7049
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6827
7050
  } else if (use_src1) {
6828
7051
  ggml_vk_sync_buffers(subctx);
6829
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7052
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6830
7053
  } else {
6831
7054
  ggml_vk_sync_buffers(subctx);
6832
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
7055
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
6833
7056
  }
6834
7057
  }
6835
7058
 
@@ -6942,7 +7165,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
6942
7165
  GGML_ASSERT(pipeline != nullptr);
6943
7166
 
6944
7167
  if (dryrun) {
6945
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
7168
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
6946
7169
  return;
6947
7170
  }
6948
7171
 
@@ -6998,7 +7221,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
6998
7221
  vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
6999
7222
  vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
7000
7223
  vk_subbuffer{ d_D, dst_offset, dst_size }
7001
- }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
7224
+ }, pc, elements);
7002
7225
  } else if (version == 7) {
7003
7226
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
7004
7227
  vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
@@ -7009,7 +7232,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
7009
7232
  vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
7010
7233
  vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
7011
7234
  vk_subbuffer{ d_D, dst_offset, dst_size }
7012
- }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
7235
+ }, pc, elements);
7013
7236
  } else {
7014
7237
  // shouldn't happen
7015
7238
  GGML_ASSERT(false);
@@ -7081,7 +7304,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
7081
7304
  GGML_ASSERT(pipeline != nullptr);
7082
7305
 
7083
7306
  if (dryrun) {
7084
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
7307
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7085
7308
  return;
7086
7309
  }
7087
7310
 
@@ -7146,7 +7369,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
7146
7369
  vk_subbuffer{ d_GM, gm_offset, gm_size },
7147
7370
  vk_subbuffer{ d_GV, gv_offset, gv_size },
7148
7371
  vk_subbuffer{ d_P, p_offset, p_size },
7149
- }, sizeof(vk_op_push_constants), &pc, elements);
7372
+ }, pc, elements);
7150
7373
  }
7151
7374
 
7152
7375
  static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -7352,18 +7575,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
7352
7575
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
7353
7576
  }
7354
7577
 
7355
- static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7578
+ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7356
7579
  float * op_params = (float *)dst->op_params;
7357
7580
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7581
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7358
7582
  const uint32_t dst_type_size = ggml_type_size(dst->type);
7359
7583
 
7360
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7584
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
7361
7585
  (uint32_t)ggml_nelements(src0),
7362
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7363
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7586
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7587
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7588
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7364
7589
  0,
7365
- op_params[0], 0.0f,
7366
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7590
+ op_params[0], 0.0f, 0,
7367
7591
  }, dryrun);
7368
7592
  }
7369
7593
 
@@ -7381,6 +7605,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
7381
7605
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7382
7606
  }
7383
7607
 
7608
+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7609
+ const bool swapped = (bool)dst->op_params[1];
7610
+ const bool split = src1 != nullptr;
7611
+
7612
+ GGML_ASSERT(ggml_is_contiguous(src0));
7613
+
7614
+ if (!split) {
7615
+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7616
+ } else {
7617
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7618
+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7619
+ GGML_ASSERT(src0->type == src1->type);
7620
+ }
7621
+
7622
+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
7623
+
7624
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
7625
+ }
7626
+
7384
7627
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7385
7628
  int32_t * op_params = (int32_t *)dst->op_params;
7386
7629
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
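ggml_vk_glu above packs the operand layout into a single mode value: 0 for the two halves of src0 in their stored order, 1 for the swapped order, 2 when src1 supplies the second operand (the split case). A hedged CPU reference for the SWIGLU flavour; GEGLU and REGLU differ only in the activation applied to the gate half.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// SWIGLU on one row: out[i] = silu(gate[i]) * up[i].
// In the fused op, gate/up come either from the two halves of src0
// (in stored or swapped order) or from src0/src1 rows in the split case.
static std::vector<float> swiglu_row(const std::vector<float> & gate,
                                     const std::vector<float> & up) {
    std::vector<float> out(gate.size());
    for (size_t i = 0; i < gate.size(); i++) {
        const float silu = gate[i] / (1.0f + std::exp(-gate[i]));
        out[i] = silu * up[i];
    }
    return out;
}
```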
@@ -7528,6 +7771,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
7528
7771
  }, dryrun);
7529
7772
  }
7530
7773
 
7774
+ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7775
+ // src0: (K, Cout, Cin, 1) -- kernel
7776
+ // src1: (L, Cin, 1, 1) -- input
7777
+ // dst: (*, Cout, 1, 1)
7778
+
7779
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
7780
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
7781
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
7782
+
7783
+ GGML_TENSOR_BINARY_OP_LOCALS
7784
+
7785
+ GGML_ASSERT(nb00 == sizeof(float));
7786
+ GGML_ASSERT(nb10 == sizeof(float));
7787
+
7788
+ const int32_t s0 = dst->op_params[0];
7789
+
7790
+ vk_op_conv_transpose_1d_push_constants p{};
7791
+ p.Cout = static_cast<uint32_t>(ne01);
7792
+ p.Cin = static_cast<uint32_t>(ne02);
7793
+ p.K = static_cast<uint32_t>(ne00);
7794
+ p.L = static_cast<uint32_t>(ne10);
7795
+ p.KL = static_cast<uint32_t>(ne0);
7796
+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
7797
+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
7798
+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
7799
+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
7800
+ p.s0 = static_cast<uint32_t>(s0);
7801
+
7802
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun);
7803
+ }
7804
+
7531
7805
  static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7532
7806
  uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
7533
7807
  const int32_t k1 = dst->op_params[1];
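The conv_transpose_1d push constants above are just the extents and strides the shader needs; the op itself is the standard 1-D transposed convolution with stride s0, whose output length (L - 1) * s0 + K is what p.KL = ne0 carries. A hedged CPU reference with flattened buffers; the function name and explicit layouts below are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference conv_transpose_1d matching the layouts noted in the hunk:
// kernel (K, Cout, Cin), input (L, Cin), output ((L - 1) * s0 + K, Cout),
// all stored contiguously with the first dimension fastest.
static std::vector<float> conv_transpose_1d_ref(const std::vector<float> & kernel,
                                                const std::vector<float> & input,
                                                uint32_t Cout, uint32_t Cin,
                                                uint32_t K, uint32_t L, uint32_t s0) {
    const uint32_t KL = (L - 1) * s0 + K;  // matches p.KL = ne0
    std::vector<float> out((size_t) Cout * KL, 0.0f);
    for (uint32_t co = 0; co < Cout; co++) {
        for (uint32_t ci = 0; ci < Cin; ci++) {
            for (uint32_t l = 0; l < L; l++) {
                for (uint32_t k = 0; k < K; k++) {
                    out[(size_t) co * KL + l * s0 + k] +=
                        kernel[((size_t) ci * Cout + co) * K + k] * input[(size_t) ci * L + l];
                }
            }
        }
    }
    return out;
}
```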
@@ -7728,9 +8002,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
7728
8002
  }
7729
8003
  }
7730
8004
 
7731
- ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
8005
+ ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
7732
8006
  if (split_k > 1) {
7733
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
8007
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
7734
8008
 
7735
8009
  if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
7736
8010
  // Resize buffer
@@ -7745,7 +8019,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
7745
8019
  ggml_vk_load_shaders(ctx->device);
7746
8020
  }
7747
8021
 
7748
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
8022
+ ggml_pipeline_allocate_descriptor_sets(ctx);
7749
8023
 
7750
8024
  vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
7751
8025
  vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7787,7 +8061,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
7787
8061
  ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
7788
8062
  ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
7789
8063
 
7790
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
8064
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
7791
8065
  ggml_vk_ctx_begin(ctx->device, subctx);
7792
8066
  for (size_t i = 0; i < num_it; i++) {
7793
8067
  ggml_vk_matmul(
@@ -7803,6 +8077,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
7803
8077
  ggml_vk_submit(subctx, ctx->fence);
7804
8078
  VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
7805
8079
  ctx->device->device.resetFences({ ctx->fence });
8080
+ ggml_vk_queue_command_pools_cleanup(ctx->device);
7806
8081
 
7807
8082
  auto end = std::chrono::high_resolution_clock::now();
7808
8083
  double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
@@ -7904,16 +8179,13 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
7904
8179
 
7905
8180
  free(d_chk);
7906
8181
 
7907
- ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
7908
- ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
8182
+ ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
8183
+ ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
7909
8184
 
7910
8185
  ggml_vk_destroy_buffer(d_X);
7911
8186
  ggml_vk_destroy_buffer(d_Y);
7912
8187
  ggml_vk_destroy_buffer(d_D);
7913
8188
 
7914
- ggml_pipeline_cleanup(p);
7915
- ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
7916
-
7917
8189
  free(x);
7918
8190
  free(y);
7919
8191
  free(d);
@@ -7991,20 +8263,20 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7991
8263
  ggml_vk_quantize_data(x, qx, ne, quant);
7992
8264
  ggml_vk_dequantize_data(qx, x_ref, ne, quant);
7993
8265
 
7994
- ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
8266
+ ggml_pipeline_request_descriptor_sets(ctx, p, 1);
7995
8267
 
7996
8268
  if (ctx->device->need_compiles) {
7997
8269
  ggml_vk_load_shaders(ctx->device);
7998
8270
  }
7999
8271
 
8000
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
8272
+ ggml_pipeline_allocate_descriptor_sets(ctx);
8001
8273
 
8002
8274
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
8003
8275
 
8004
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
8276
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
8005
8277
  ggml_vk_ctx_begin(ctx->device, subctx);
8006
8278
  const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
8007
- ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
8279
+ ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
8008
8280
  ggml_vk_ctx_end(subctx);
8009
8281
 
8010
8282
  auto begin = std::chrono::high_resolution_clock::now();
@@ -8012,6 +8284,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
8012
8284
  ggml_vk_submit(subctx, ctx->fence);
8013
8285
  VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
8014
8286
  ctx->device->device.resetFences({ ctx->fence });
8287
+ ggml_vk_queue_command_pools_cleanup(ctx->device);
8015
8288
 
8016
8289
  auto end = std::chrono::high_resolution_clock::now();
8017
8290
 
@@ -8091,17 +8364,17 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
8091
8364
  //
8092
8365
  // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
8093
8366
  //
8094
- // ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
8367
+ // ggml_pipeline_request_descriptor_sets(ctx, p, 1);
8095
8368
  //
8096
8369
  // if (ctx->device->need_compiles) {
8097
8370
  // ggml_vk_load_shaders(ctx->device);
8098
8371
  // }
8099
8372
  //
8100
- // ggml_pipeline_allocate_descriptor_sets(ctx->device);
8373
+ // ggml_pipeline_allocate_descriptor_sets(ctx);
8101
8374
  //
8102
8375
  // ggml_vk_buffer_write(x_buf, 0, x, x_sz);
8103
8376
  //
8104
- // vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
8377
+ // vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
8105
8378
  // ggml_vk_ctx_begin(ctx->device, subctx);
8106
8379
  // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
8107
8380
  // ggml_vk_ctx_end(subctx);
@@ -8111,6 +8384,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
8111
8384
  // ggml_vk_submit(subctx, ctx->fence);
8112
8385
  // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
8113
8386
  // ctx->device->device.resetFences({ ctx->fence });
8387
+ // ggml_vk_queue_command_pools_cleanup(ctx->device);
8114
8388
  //
8115
8389
  // auto end = std::chrono::high_resolution_clock::now();
8116
8390
  //
@@ -8250,9 +8524,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
8250
8524
  // y[i] = i % k;
8251
8525
  }
8252
8526
 
8253
- ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
8527
+ ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
8254
8528
  if (split_k > 1) {
8255
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
8529
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
8256
8530
 
8257
8531
  if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
8258
8532
  // Resize buffer
@@ -8263,19 +8537,19 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
8263
8537
  }
8264
8538
  }
8265
8539
  if (mmq) {
8266
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
8540
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
8267
8541
  }
8268
8542
 
8269
8543
  if (ctx->device->need_compiles) {
8270
8544
  ggml_vk_load_shaders(ctx->device);
8271
8545
  }
8272
8546
 
8273
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
8547
+ ggml_pipeline_allocate_descriptor_sets(ctx);
8274
8548
 
8275
8549
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
8276
8550
  ggml_vk_buffer_write(y_buf, 0, y, y_sz);
8277
8551
 
8278
- vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
8552
+ vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
8279
8553
  ggml_vk_ctx_begin(ctx->device, subctx);
8280
8554
  if (mmq) {
8281
8555
  for (size_t i = 0; i < num_it; i++) {
@@ -8304,6 +8578,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
8304
8578
  ggml_vk_submit(subctx, ctx->fence);
8305
8579
  VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
8306
8580
  ctx->device->device.resetFences({ ctx->fence });
8581
+ ggml_vk_queue_command_pools_cleanup(ctx->device);
8307
8582
 
8308
8583
  auto end = std::chrono::high_resolution_clock::now();
8309
8584
 
@@ -8526,7 +8801,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t

  // Returns true if node has enqueued work into the queue, false otherwise
  // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+ ggml_tensor * node = cgraph->nodes[node_idx];
  if (ggml_is_empty(node) || !node->buffer) {
  return false;
  }
@@ -8560,6 +8836,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(node)) {
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_SWIGLU:
+ break;
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_REPEAT:
  case GGML_OP_REPEAT_BACK:
  case GGML_OP_GET_ROWS:
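
The new GGML_OP_GLU case above lets ggml_vk_build_graph accept GEGLU, REGLU and SWIGLU nodes. As a rough usage sketch (not taken from this file), the two constructors that appear later in this diff, ggml_glu and ggml_glu_split, cover the fused single-tensor form and the two-tensor split form; shapes and the swapped flag below are placeholders.

```cpp
// Sketch: building GLU nodes that the Vulkan backend can now execute.
// Uses only ggml_glu / ggml_glu_split, whose call shapes are visible in the
// GGML_VULKAN_CHECK_RESULTS hunk further down; values are illustrative.
#include "ggml.h"

// Gated form: `up_gate` packs [up | gate] along dim 0; the op splits it and
// returns a half-width result.
ggml_tensor * gated_swiglu(ggml_context * ctx, ggml_tensor * up_gate) {
    return ggml_glu(ctx, up_gate, GGML_GLU_OP_SWIGLU, /*swapped=*/false);
}

// Split form: activation input and gate come from two tensors of equal shape.
ggml_tensor * split_geglu(ggml_context * ctx, ggml_tensor * up, ggml_tensor * gate) {
    return ggml_glu_split(ctx, up, gate, GGML_GLU_OP_GEGLU);
}
```
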
@@ -8599,6 +8885,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_CONV_TRANSPOSE_1D:
  case GGML_OP_POOL_2D:
  case GGML_OP_CONV_2D_DW:
  case GGML_OP_RWKV_WKV6:
@@ -8617,7 +8904,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  if (!dryrun) {
  if (ctx->compute_ctx.expired()) {
- compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
  ctx->compute_ctx = compute_ctx;
  ggml_vk_ctx_begin(ctx->device, compute_ctx);
  } else {
@@ -8651,6 +8938,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_L2_NORM:
  case GGML_OP_UNARY:
+ case GGML_OP_GLU:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
  case GGML_OP_SOFT_MAX_BACK:
@@ -8663,6 +8951,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_CONV_TRANSPOSE_1D:
  case GGML_OP_POOL_2D:
  case GGML_OP_CONV_2D_DW:
  case GGML_OP_LEAKY_RELU:
@@ -8670,7 +8959,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  // These operations all go through ggml_vk_op_f32, so short-circuit and
  // do the only thing needed for the dryrun.
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
- ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
  return false;
  }
  default:
@@ -8762,8 +9051,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  break;
  case GGML_OP_RMS_NORM:
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+ if (ctx->num_additional_fused_ops > 0) {
+ // fused rms_norm + mul
+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+ ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+ } else {
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+ }
  break;
  case GGML_OP_RMS_NORM_BACK:
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
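
The fused branch above fires when an RMS_NORM node is immediately followed by the MUL node that consumes it, i.e. the common `rms_norm(x) * weight` layer pattern that ggml_can_fuse detects during the dry run. A small sketch of that producer/consumer pair, built from standard ggml ops; eps is a placeholder.

```cpp
// Sketch: the adjacent RMS_NORM -> MUL pair that the backend can now lower to a
// single fused ggml_vk_rms_norm dispatch. eps is illustrative.
#include "ggml.h"

ggml_tensor * rms_norm_mul(ggml_context * ctx, ggml_tensor * x, ggml_tensor * weight) {
    const float eps = 1e-6f;
    ggml_tensor * normed = ggml_rms_norm(ctx, x, eps); // GGML_OP_RMS_NORM node
    return ggml_mul(ctx, normed, weight);              // GGML_OP_MUL consuming it directly
}
```
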
@@ -8787,6 +9082,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(node)) {
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_SWIGLU:
+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
+ break;
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_DIAG_MASK_INF:
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);

@@ -8834,6 +9140,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_TIMESTEP_EMBEDDING:
  ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_POOL_2D:
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
@@ -8885,7 +9195,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  ctx->tensor_ctxs[node_idx] = compute_ctx;

- #if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF)
+ #if defined(GGML_VULKAN_CHECK_RESULTS)
  // Force context reset on each node so that each tensor ends up in its own context
  // and can be run and compared to its CPU equivalent separately
  last_node = true;
@@ -8908,8 +9218,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  if (!ok) {
  if (node->op == GGML_OP_UNARY) {
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
- }
- else {
+ } else if (node->op == GGML_OP_GLU) {
+ std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+ } else {
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
  }
  }
@@ -8962,6 +9273,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_COUNT_EQUAL:
  case GGML_OP_IM2COL:
  case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_CONV_TRANSPOSE_1D:
  case GGML_OP_POOL_2D:
  case GGML_OP_CONV_2D_DW:
  case GGML_OP_RWKV_WKV6:
@@ -8987,6 +9299,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(tensor)) {
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_SWIGLU:
+ buf = tensor->buffer;
+ break;
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  case GGML_OP_FLASH_ATTN_EXT:
@@ -9057,19 +9380,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
  }
  ctx->gc.temp_buffers.clear();

- for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) {
- vk_pipeline_ref plr = ctx->device->pipelines[dsr.first];
-
- if (plr.expired()) {
- continue;
- }
-
- vk_pipeline pl = plr.lock();
- ggml_pipeline_cleanup(pl);
- }
-
- ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
- ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
+ ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
+ ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);

  for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
  ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -9090,7 +9402,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {

  ctx->tensor_ctxs.clear();
  ctx->gc.contexts.clear();
- ctx->device->pipeline_descriptor_set_requirements.clear();
+ ctx->pipeline_descriptor_set_requirements = 0;
+ ctx->descriptor_set_idx = 0;
  }

  // Clean up on backend free
@@ -9117,6 +9430,15 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {

  ctx->device->device.destroyFence(ctx->fence);
  ctx->device->device.destroyFence(ctx->almost_ready_fence);
+
+ for (auto& pool : ctx->descriptor_pools) {
+ ctx->device->device.destroyDescriptorPool(pool);
+ }
+ ctx->descriptor_pools.clear();
+ ctx->descriptor_sets.clear();
+
+ ctx->compute_cmd_pool.destroy(ctx->device->device);
+ ctx->transfer_cmd_pool.destroy(ctx->device->device);
  }

  static int ggml_vk_get_device_count() {
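
ggml_vk_cleanup now also destroys the per-context descriptor pools and the two per-context command pools. The diff only shows the destroy calls, so the sketch below is an assumption about what a wrapper with that `destroy(vk::Device)` interface might look like; the real vk_command_pool in ggml-vulkan.cpp likely carries more state (queue family, cached command buffers).

```cpp
// Sketch of a per-context command pool wrapper matching the
// ctx->compute_cmd_pool.destroy(device) call pattern above. Illustrative only.
#include <vulkan/vulkan.hpp>
#include <vector>

struct cmd_pool_sketch {
    vk::CommandPool pool;
    std::vector<vk::CommandBuffer> cmd_buffers; // buffers allocated from `pool`

    void init(vk::Device device, uint32_t queue_family_index) {
        vk::CommandPoolCreateInfo info(vk::CommandPoolCreateFlagBits::eTransient, queue_family_index);
        pool = device.createCommandPool(info);
    }

    void destroy(vk::Device device) {
        // Destroying the pool frees every command buffer allocated from it.
        device.destroyCommandPool(pool);
        pool = vk::CommandPool{};
        cmd_buffers.clear();
    }
};
```
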
@@ -9324,6 +9646,12 @@ static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer
  UNUSED(buft);
  }

+ static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ return vk_instance.devices[0]->suballocation_block_size;
+
+ UNUSED(buft);
+ }
+
  // Should be changed to return device-specific host buffer type
  // but that probably requires changes in llama.cpp
  ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
@@ -9332,7 +9660,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
  /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
  /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
  /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
  /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
  /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
  },
@@ -9383,7 +9711,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor

  if (ctx->transfer_ctx.expired()) {
  // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
  ctx->transfer_ctx = transfer_ctx;
  ggml_vk_ctx_begin(ctx->device, transfer_ctx);
  } else {
@@ -9406,7 +9734,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_

  if (ctx->transfer_ctx.expired()) {
  // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
  ctx->transfer_ctx = transfer_ctx;
  ggml_vk_ctx_begin(ctx->device, transfer_ctx);
  } else {
@@ -9429,7 +9757,7 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_

  if (ctx->transfer_ctx.expired()) {
  // Initialize new transfer context
- transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
  ctx->transfer_ctx = transfer_ctx;
  ggml_vk_ctx_begin(ctx->device, transfer_ctx);
  } else {
@@ -9479,18 +9807,30 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

+ if (vk_instance.debug_utils_support) {
+ vk::DebugUtilsLabelEXT dul = {};
+ dul.pLabelName = "ggml_backend_vk_graph_compute";
+ dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
+ vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
+ }
+
  uint64_t total_mat_mul_bytes = 0;
  for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+ if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
+ }
+ ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
  }
+ i += ctx->num_additional_fused_ops;
+ ctx->num_additional_fused_ops = 0;
  }
  if (ctx->device->need_compiles) {
  ggml_vk_load_shaders(ctx->device);
  }
  ggml_vk_preallocate_buffers(ctx);
- ggml_pipeline_allocate_descriptor_sets(ctx->device);
+ ggml_pipeline_allocate_descriptor_sets(ctx);

  int last_node = cgraph->n_nodes - 1;

@@ -9505,6 +9845,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  bool first_node_in_batch = true; // true if next node will be first node in a batch
  int submit_node_idx = 0; // index to first node in a batch

+ vk_context compute_ctx;
+ if (vk_perf_logger_enabled) {
+ // allocate/resize the query pool
+ if (ctx->device->num_queries < cgraph->n_nodes + 1) {
+ if (ctx->device->query_pool) {
+ ctx->device->device.destroyQueryPool(ctx->device->query_pool);
+ }
+ vk::QueryPoolCreateInfo query_create_info;
+ query_create_info.queryType = vk::QueryType::eTimestamp;
+ query_create_info.queryCount = cgraph->n_nodes + 100;
+ ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
+ ctx->device->num_queries = query_create_info.queryCount;
+ }
+
+ ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
+
+ GGML_ASSERT(ctx->compute_ctx.expired());
+ compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
+ }
+
  // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
  // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
  // (and scaled down based on model size, so smaller models submit earlier).
@@ -9523,14 +9886,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
  }

+ if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
+ }
+
  // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
  bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
  bool submit = (submitted_nodes >= nodes_per_submit) ||
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
- (i == last_node) ||
+ (i + ctx->num_additional_fused_ops == last_node) ||
  (almost_ready && !ctx->almost_ready_fence_pending);

- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
+
+ if (vk_perf_logger_enabled) {
+ if (ctx->compute_ctx.expired()) {
+ compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ } else {
+ compute_ctx = ctx->compute_ctx.lock();
+ }
+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+ }
+ }

  if (enqueued) {
  ++submitted_nodes;
@@ -9551,11 +9932,31 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  }
  submit_count++;
  }
+ i += ctx->num_additional_fused_ops;
+ ctx->num_additional_fused_ops = 0;
  }

- #ifdef GGML_VULKAN_PERF
- ctx->device->perf_logger->print_timings();
- #endif
+ if (vk_perf_logger_enabled) {
+ // End the command buffer and submit/wait
+ GGML_ASSERT(!ctx->compute_ctx.expired());
+ compute_ctx = ctx->compute_ctx.lock();
+ ggml_vk_ctx_end(compute_ctx);
+
+ ggml_vk_submit(compute_ctx, ctx->device->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
+ ctx->device->device.resetFences({ ctx->device->fence });
+
+ // Get the results and pass them to the logger
+ std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
+ VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (!ggml_vk_is_empty(cgraph->nodes[i])) {
+ ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
+ }
+ }
+
+ ctx->device->perf_logger->print_timings();
+ }

  ggml_vk_graph_cleanup(ctx);
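
When vk_perf_logger_enabled is set, the code above writes one timestamp per node, reads them back with a single getQueryPoolResults call, and scales tick deltas by the device's timestampPeriod before handing them to the perf logger. A reduced sketch of that readback and conversion, with the pool handle and query indices as placeholders:

```cpp
// Sketch: turning two GPU timestamps into a nanosecond duration, mirroring the
// readback above. `pool` and `first_query` are placeholders; timestamp_period
// comes from VkPhysicalDeviceLimits::timestampPeriod (ns per tick).
#include <vulkan/vulkan.hpp>
#include <cstdint>

uint64_t region_time_ns(vk::Device device, vk::QueryPool pool,
                        uint32_t first_query, float timestamp_period) {
    uint64_t ts[2] = {};
    // eWait blocks until both timestamps are available, as in the code above.
    (void) device.getQueryPoolResults(pool, first_query, 2, sizeof(ts), ts,
                                      sizeof(uint64_t),
                                      vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
    return uint64_t((ts[1] - ts[0]) * timestamp_period);
}
```
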

@@ -9707,6 +10108,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  return false;
  }
  break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_SWIGLU:
+ return ggml_is_contiguous(op->src[0]) &&
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+ (op->src[0]->type == op->type);
+ default:
+ return false;
+ }
+ break;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  {
@@ -9971,6 +10385,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_OPT_STEP_ADAMW:
  return true;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
  default:
  return false;
  }
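
Per the new supports_op entry above, CONV_TRANSPOSE_1D is only taken when both operands are F32. A small construction sketch with illustrative shapes; the argument order (stride s0, padding p0, dilation d0) matches the reference-path call near the end of this diff.

```cpp
// Sketch: a CONV_TRANSPOSE_1D node the Vulkan backend now accepts. Kernel is
// (K, C_out, C_in), input is (L, C_in); all sizes are illustrative.
#include "ggml.h"

ggml_tensor * make_conv_transpose_1d(ggml_context * ctx) {
    ggml_tensor * kernel = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 8, 16);
    ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 16);
    return ggml_conv_transpose_1d(ctx, kernel, input, /*s0=*/2, /*p0=*/0, /*d0=*/1);
}
```
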
@@ -10114,11 +10530,28 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
  UNUSED(instance_extensions);
  }

+ // Extension availability
+ static bool ggml_vk_instance_debug_utils_ext_available(
+ const std::vector<vk::ExtensionProperties> & instance_extensions) {
+ // Check for portability enumeration extension for MoltenVK support
+ for (const auto & properties : instance_extensions) {
+ if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
+ return true;
+ }
+ }
+
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
+ return false;
+
+ UNUSED(instance_extensions);
+ }
+
  static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
  switch (props.vendorID) {
  case VK_VENDOR_ID_INTEL:
- // Intel drivers don't support coopmat properly yet
- return false;
+ // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
+ // while some older hardware (ex. Arc A770) has performance regressions
+ return arch == vk_device_architecture::INTEL_XE2;
  case VK_VENDOR_ID_AMD:
  if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
  // Workaround for AMD proprietary driver reporting support on all GPUs
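
The ggml_vk_instance_debug_utils_ext_available helper added above gates loading of the VK_EXT_debug_utils entry points that graph_compute now uses to label its queue work. A hedged sketch of the usual begin/end pairing, using only standard Vulkan symbols; the matching End call is not shown in these hunks, so where the backend actually places it is an assumption left open here.

```cpp
// Sketch: bracketing a region of queue work with VK_EXT_debug_utils labels so
// capture tools can group it. Entry points are loaded via vkGetInstanceProcAddr.
#include <vulkan/vulkan.h>

void label_queue_region(VkInstance instance, VkQueue queue, const char * name) {
    auto begin_fn = reinterpret_cast<PFN_vkQueueBeginDebugUtilsLabelEXT>(
        vkGetInstanceProcAddr(instance, "vkQueueBeginDebugUtilsLabelEXT"));
    auto end_fn = reinterpret_cast<PFN_vkQueueEndDebugUtilsLabelEXT>(
        vkGetInstanceProcAddr(instance, "vkQueueEndDebugUtilsLabelEXT"));
    if (!begin_fn || !end_fn) {
        return; // extension not enabled on this instance
    }

    VkDebugUtilsLabelEXT label = {};
    label.sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT;
    label.pLabelName = name;
    begin_fn(queue, &label);
    // ... submit work on `queue` here ...
    end_fn(queue);
}
```
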
@@ -10418,6 +10851,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
  GGML_ABORT("fatal error");
  }
+ } else if (tensor->op == GGML_OP_GLU) {
+ if (src_clone[1] == nullptr) {
+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+ } else {
+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
+ }
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
  if (src1 == nullptr) {
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -10462,6 +10901,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  const int32_t dim = tensor->op_params[0];
  const int32_t max_period = tensor->op_params[1];
  tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
+ } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
+ const int32_t s0 = tensor->op_params[0];
+ const int32_t p0 = tensor->op_params[1];
+ const int32_t d0 = tensor->op_params[2];
+ tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
  } else if (tensor->op == GGML_OP_POOL_2D) {
  enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
  const int32_t k0 = tensor->op_params[1];