llama-cpp-pydist 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. llama_cpp/binaries/{llama-b7488-bin-win-cpu-x64.zip → llama-b7631-bin-win-cpu-x64.zip} +0 -0
  2. llama_cpp_pydist-0.21.0.dist-info/METADATA +4684 -0
  3. {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/RECORD +240 -222
  4. {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/WHEEL +1 -1
  5. vendor_llama_cpp_pydist/llama.cpp/.devops/cuda-new.Dockerfile +95 -0
  6. vendor_llama_cpp_pydist/llama.cpp/.gemini/settings.json +1 -0
  7. vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +2 -1
  8. vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/011-bug-results.yml +13 -2
  9. vendor_llama_cpp_pydist/llama.cpp/.github/ISSUE_TEMPLATE/019-bug-misc.yml +13 -2
  10. vendor_llama_cpp_pydist/llama.cpp/.github/workflows/build.yml +18 -6
  11. vendor_llama_cpp_pydist/llama.cpp/.github/workflows/docker.yml +25 -13
  12. vendor_llama_cpp_pydist/llama.cpp/.github/workflows/release.yml +9 -5
  13. vendor_llama_cpp_pydist/llama.cpp/.github/workflows/server.yml +18 -0
  14. vendor_llama_cpp_pydist/llama.cpp/AGENTS.md +81 -0
  15. vendor_llama_cpp_pydist/llama.cpp/CLAUDE.md +1 -0
  16. vendor_llama_cpp_pydist/llama.cpp/CONTRIBUTING.md +34 -5
  17. vendor_llama_cpp_pydist/llama.cpp/ci/run.sh +2 -1
  18. vendor_llama_cpp_pydist/llama.cpp/common/CMakeLists.txt +4 -3
  19. vendor_llama_cpp_pydist/llama.cpp/common/arg.cpp +46 -14
  20. vendor_llama_cpp_pydist/llama.cpp/common/arg.h +1 -0
  21. vendor_llama_cpp_pydist/llama.cpp/common/chat-parser.cpp +11 -0
  22. vendor_llama_cpp_pydist/llama.cpp/common/chat.cpp +36 -7
  23. vendor_llama_cpp_pydist/llama.cpp/common/chat.h +1 -0
  24. vendor_llama_cpp_pydist/llama.cpp/common/common.cpp +42 -23
  25. vendor_llama_cpp_pydist/llama.cpp/common/common.h +7 -2
  26. vendor_llama_cpp_pydist/llama.cpp/common/llguidance.cpp +10 -6
  27. vendor_llama_cpp_pydist/llama.cpp/common/regex-partial.cpp +13 -13
  28. vendor_llama_cpp_pydist/llama.cpp/common/sampling.cpp +58 -14
  29. vendor_llama_cpp_pydist/llama.cpp/common/sampling.h +3 -1
  30. vendor_llama_cpp_pydist/llama.cpp/convert_hf_to_gguf.py +424 -103
  31. vendor_llama_cpp_pydist/llama.cpp/convert_hf_to_gguf_update.py +5 -0
  32. vendor_llama_cpp_pydist/llama.cpp/docs/backend/CANN.md +4 -0
  33. vendor_llama_cpp_pydist/llama.cpp/docs/backend/OPENCL.md +51 -1
  34. vendor_llama_cpp_pydist/llama.cpp/docs/backend/SYCL.md +1 -1
  35. vendor_llama_cpp_pydist/llama.cpp/docs/backend/hexagon/README.md +5 -5
  36. vendor_llama_cpp_pydist/llama.cpp/docs/backend/hexagon/developer.md +1 -1
  37. vendor_llama_cpp_pydist/llama.cpp/docs/build.md +21 -2
  38. vendor_llama_cpp_pydist/llama.cpp/docs/development/parsing.md +2 -2
  39. vendor_llama_cpp_pydist/llama.cpp/docs/ops/Metal.csv +360 -322
  40. vendor_llama_cpp_pydist/llama.cpp/docs/ops.md +1 -1
  41. vendor_llama_cpp_pydist/llama.cpp/ggml/CMakeLists.txt +13 -1
  42. vendor_llama_cpp_pydist/llama.cpp/ggml/include/ggml-backend.h +1 -1
  43. vendor_llama_cpp_pydist/llama.cpp/ggml/src/CMakeLists.txt +23 -9
  44. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-backend.cpp +11 -11
  45. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +303 -19
  46. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +17 -0
  47. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/common.h +153 -9
  48. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +51 -158
  49. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
  50. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  51. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +86 -25
  52. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +15 -8
  53. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +768 -0
  54. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +0 -4
  55. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +66 -1
  56. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/argsort.cu +48 -27
  57. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/argsort.cuh +16 -0
  58. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/common.cuh +45 -9
  59. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/cpy.cu +117 -103
  60. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/cumsum.cu +105 -35
  61. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +3 -1
  62. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +2 -2
  63. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +83 -33
  64. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mean.cu +3 -0
  65. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mma.cuh +21 -0
  66. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mmq.cu +34 -8
  67. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +168 -13
  68. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/quantize.cu +151 -0
  69. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/quantize.cuh +14 -0
  70. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/softmax.cu +203 -6
  71. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/top-k.cu +96 -0
  72. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/top-k.cuh +3 -0
  73. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +17 -2
  74. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +6 -1
  75. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  76. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  77. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  78. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +224 -758
  79. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +316 -164
  80. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +5 -11
  81. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +46 -15
  82. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +9 -3
  83. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +2 -1
  84. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +20 -20
  85. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +1 -0
  86. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
  87. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-impl.h +0 -4
  88. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +57 -0
  89. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +2 -0
  90. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +5 -0
  91. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +20 -0
  92. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +71 -2
  93. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +1 -0
  94. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +73 -6
  95. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +134 -13
  96. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +21 -0
  97. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  98. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +14 -7
  99. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +42 -1
  100. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +742 -315
  101. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  102. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +28 -14
  103. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  104. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  105. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +1 -7
  106. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl +2 -0
  107. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +17 -4
  108. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +42 -24
  109. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +11 -0
  110. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +115 -0
  111. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +10 -4
  112. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +29 -18
  113. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +19 -16
  114. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  115. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +10 -4
  116. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +8 -8
  117. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +11 -4
  118. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +4 -1
  119. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +4 -1
  120. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +4 -1
  121. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +1 -0
  122. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +4 -1
  123. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +57 -22
  124. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +312 -6
  125. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +54 -0
  126. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -2
  127. vendor_llama_cpp_pydist/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  128. vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/constants.py +99 -0
  129. vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/gguf_writer.py +38 -2
  130. vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/tensor_mapping.py +26 -0
  131. vendor_llama_cpp_pydist/llama.cpp/gguf-py/gguf/utility.py +0 -8
  132. vendor_llama_cpp_pydist/llama.cpp/grammars/README.md +3 -0
  133. vendor_llama_cpp_pydist/llama.cpp/include/llama.h +99 -12
  134. vendor_llama_cpp_pydist/llama.cpp/scripts/snapdragon/adb/run-cli.sh +9 -9
  135. vendor_llama_cpp_pydist/llama.cpp/scripts/snapdragon/adb/run-completion.sh +53 -0
  136. vendor_llama_cpp_pydist/llama.cpp/scripts/sync-ggml.last +1 -1
  137. vendor_llama_cpp_pydist/llama.cpp/src/CMakeLists.txt +4 -0
  138. vendor_llama_cpp_pydist/llama.cpp/src/llama-adapter.cpp +12 -3
  139. vendor_llama_cpp_pydist/llama.cpp/src/llama-adapter.h +7 -1
  140. vendor_llama_cpp_pydist/llama.cpp/src/llama-arch.cpp +76 -0
  141. vendor_llama_cpp_pydist/llama.cpp/src/llama-arch.h +7 -0
  142. vendor_llama_cpp_pydist/llama.cpp/src/llama-chat.cpp +11 -0
  143. vendor_llama_cpp_pydist/llama.cpp/src/llama-chat.h +1 -0
  144. vendor_llama_cpp_pydist/llama.cpp/src/llama-context.cpp +625 -40
  145. vendor_llama_cpp_pydist/llama.cpp/src/llama-context.h +43 -1
  146. vendor_llama_cpp_pydist/llama.cpp/src/llama-grammar.cpp +40 -13
  147. vendor_llama_cpp_pydist/llama.cpp/src/llama-grammar.h +2 -0
  148. vendor_llama_cpp_pydist/llama.cpp/src/llama-graph.cpp +166 -2
  149. vendor_llama_cpp_pydist/llama.cpp/src/llama-graph.h +71 -6
  150. vendor_llama_cpp_pydist/llama.cpp/src/llama-hparams.h +6 -5
  151. vendor_llama_cpp_pydist/llama.cpp/src/llama-kv-cache.h +1 -1
  152. vendor_llama_cpp_pydist/llama.cpp/src/llama-mmap.cpp +11 -4
  153. vendor_llama_cpp_pydist/llama.cpp/src/llama-model-loader.cpp +23 -0
  154. vendor_llama_cpp_pydist/llama.cpp/src/llama-model-loader.h +2 -0
  155. vendor_llama_cpp_pydist/llama.cpp/src/llama-model.cpp +329 -26
  156. vendor_llama_cpp_pydist/llama.cpp/src/llama-model.h +13 -2
  157. vendor_llama_cpp_pydist/llama.cpp/src/llama-sampling.cpp +1259 -186
  158. vendor_llama_cpp_pydist/llama.cpp/src/llama-sampling.h +19 -7
  159. vendor_llama_cpp_pydist/llama.cpp/src/llama-vocab.cpp +101 -33
  160. vendor_llama_cpp_pydist/llama.cpp/src/llama-vocab.h +2 -0
  161. vendor_llama_cpp_pydist/llama.cpp/src/llama.cpp +53 -38
  162. vendor_llama_cpp_pydist/llama.cpp/src/models/afmoe.cpp +9 -5
  163. vendor_llama_cpp_pydist/llama.cpp/src/models/bert.cpp +4 -2
  164. vendor_llama_cpp_pydist/llama.cpp/src/models/cogvlm.cpp +5 -3
  165. vendor_llama_cpp_pydist/llama.cpp/src/models/cohere2-iswa.cpp +3 -0
  166. vendor_llama_cpp_pydist/llama.cpp/src/models/deepseek2.cpp +1 -1
  167. vendor_llama_cpp_pydist/llama.cpp/src/models/gemma-embedding.cpp +2 -6
  168. vendor_llama_cpp_pydist/llama.cpp/src/models/gemma2-iswa.cpp +5 -2
  169. vendor_llama_cpp_pydist/llama.cpp/src/models/gemma3.cpp +3 -4
  170. vendor_llama_cpp_pydist/llama.cpp/src/models/gemma3n-iswa.cpp +4 -7
  171. vendor_llama_cpp_pydist/llama.cpp/src/models/llama-iswa.cpp +6 -2
  172. vendor_llama_cpp_pydist/llama.cpp/src/models/llama.cpp +19 -6
  173. vendor_llama_cpp_pydist/llama.cpp/src/models/maincoder.cpp +117 -0
  174. vendor_llama_cpp_pydist/llama.cpp/src/models/mimo2-iswa.cpp +123 -0
  175. vendor_llama_cpp_pydist/llama.cpp/src/models/models.h +18 -0
  176. vendor_llama_cpp_pydist/llama.cpp/src/models/modern-bert.cpp +116 -0
  177. vendor_llama_cpp_pydist/llama.cpp/src/models/openai-moe-iswa.cpp +5 -2
  178. vendor_llama_cpp_pydist/llama.cpp/src/models/plamo3.cpp +128 -0
  179. vendor_llama_cpp_pydist/llama.cpp/src/models/smallthinker.cpp +11 -5
  180. vendor_llama_cpp_pydist/llama.cpp/src/unicode.cpp +23 -14
  181. vendor_llama_cpp_pydist/llama.cpp/tests/CMakeLists.txt +12 -2
  182. vendor_llama_cpp_pydist/llama.cpp/tests/test-backend-ops.cpp +286 -65
  183. vendor_llama_cpp_pydist/llama.cpp/tests/test-backend-sampler.cpp +1237 -0
  184. vendor_llama_cpp_pydist/llama.cpp/tests/test-chat.cpp +29 -3
  185. vendor_llama_cpp_pydist/llama.cpp/tests/test-grammar-llguidance.cpp +3 -0
  186. vendor_llama_cpp_pydist/llama.cpp/tests/test-regex-partial.cpp +14 -14
  187. vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-0.cpp +1 -1
  188. vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-1-bpe.cpp +1 -1
  189. vendor_llama_cpp_pydist/llama.cpp/tests/test-tokenizer-1-spm.cpp +1 -1
  190. vendor_llama_cpp_pydist/llama.cpp/tools/batched-bench/batched-bench.cpp +11 -0
  191. vendor_llama_cpp_pydist/llama.cpp/tools/cli/README.md +187 -1
  192. vendor_llama_cpp_pydist/llama.cpp/tools/cli/cli.cpp +1 -3
  193. vendor_llama_cpp_pydist/llama.cpp/tools/completion/README.md +179 -7
  194. vendor_llama_cpp_pydist/llama.cpp/tools/completion/completion.cpp +4 -1
  195. vendor_llama_cpp_pydist/llama.cpp/tools/fit-params/fit-params.cpp +3 -3
  196. vendor_llama_cpp_pydist/llama.cpp/tools/llama-bench/llama-bench.cpp +18 -1
  197. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/CMakeLists.txt +1 -0
  198. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip-impl.h +12 -7
  199. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip-model.h +3 -1
  200. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/clip.cpp +118 -4
  201. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/models.h +10 -0
  202. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/siglip.cpp +9 -4
  203. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/whisper-enc.cpp +9 -0
  204. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/models/youtuvl.cpp +179 -0
  205. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/mtmd.cpp +5 -1
  206. vendor_llama_cpp_pydist/llama.cpp/tools/mtmd/mtmd.h +3 -0
  207. vendor_llama_cpp_pydist/llama.cpp/tools/quantize/quantize.cpp +6 -0
  208. vendor_llama_cpp_pydist/llama.cpp/tools/server/CMakeLists.txt +0 -8
  209. vendor_llama_cpp_pydist/llama.cpp/tools/server/README-dev.md +2 -0
  210. vendor_llama_cpp_pydist/llama.cpp/tools/server/README.md +27 -14
  211. vendor_llama_cpp_pydist/llama.cpp/tools/server/public/index.html.gz +0 -0
  212. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-common.cpp +22 -24
  213. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-common.h +2 -3
  214. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-context.cpp +453 -267
  215. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-context.h +52 -15
  216. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-http.cpp +16 -10
  217. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-models.cpp +174 -62
  218. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-models.h +14 -5
  219. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-queue.cpp +78 -21
  220. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-queue.h +48 -10
  221. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-task.cpp +36 -11
  222. vendor_llama_cpp_pydist/llama.cpp/tools/server/server-task.h +28 -35
  223. vendor_llama_cpp_pydist/llama.cpp/tools/server/server.cpp +9 -5
  224. vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/unit/test_chat_completion.py +11 -2
  225. vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/unit/test_sleep.py +39 -0
  226. vendor_llama_cpp_pydist/llama.cpp/tools/server/tests/utils.py +3 -0
  227. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte +25 -1
  228. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +66 -13
  229. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +5 -0
  230. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/constants/settings-config.ts +3 -0
  231. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +125 -11
  232. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/services/chat.ts +15 -8
  233. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/stores/chat.svelte.ts +12 -3
  234. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/stores/settings.svelte.ts +4 -5
  235. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/types/api.d.ts +5 -0
  236. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/types/settings.d.ts +2 -1
  237. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/lib/utils/clipboard.ts +1 -4
  238. vendor_llama_cpp_pydist/llama.cpp/tools/server/webui/src/routes/+layout.svelte +1 -1
  239. llama_cpp_pydist-0.19.0.dist-info/METADATA +0 -2506
  240. vendor_llama_cpp_pydist/llama.cpp/.github/copilot-instructions.md +0 -262
  241. {llama_cpp_pydist-0.19.0.dist-info/licenses → llama_cpp_pydist-0.21.0.dist-info}/LICENSE +0 -0
  242. {llama_cpp_pydist-0.19.0.dist-info → llama_cpp_pydist-0.21.0.dist-info}/top_level.txt +0 -0
@@ -379,18 +379,18 @@ enum FaCodePath {
379
379
  };
380
380
 
381
381
  struct vk_fa_pipeline_state {
382
- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
383
- : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
382
+ vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
383
+ : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
384
384
 
385
385
  uint32_t HSK, HSV;
386
- bool small_rows;
386
+ bool small_rows, small_cache;
387
387
  FaCodePath path;
388
388
  bool aligned;
389
389
  bool f32acc;
390
390
 
391
391
  bool operator<(const vk_fa_pipeline_state &b) const {
392
- return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
393
- std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
392
+ return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
393
+ std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
394
394
  }
395
395
  };
396
396
 
@@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
434
434
  GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
435
435
  GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
436
436
  GGML_OP_RESHAPE };
437
+
438
+ static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
439
+ GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
440
+ GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
441
+ GGML_OP_DIV, GGML_OP_RESHAPE };
442
+
437
443
  static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
438
444
  GGML_OP_VIEW, GGML_OP_GET_ROWS };
445
+
439
446
  static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
440
447
  GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
441
448
  GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
464
471
  { 9, 0, 8 }, // reshape->src[0] == div
465
472
  };
466
473
 
474
+ //node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
475
+ //node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
476
+ //node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
477
+ //node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
478
+ //node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
479
+ //node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
480
+ //node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
481
+ //node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
482
+ //node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
483
+ //node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
484
+ //node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
485
+ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
486
+ { 1, 0, 0 }, // reshape->src[0] == sigmoid
487
+ { 2, 0, 0 }, // add->src[0] == sigmoid
488
+ { 3, 0, 2 }, // argsort->src[0] == add
489
+ { 4, 0, 3 }, // view->src[0] == argsort
490
+ { 5, 0, 1 }, // get_rows->src[0] == reshape
491
+ { 5, 1, 4 }, // get_rows->src[1] == view
492
+ { 6, 0, 5 }, // reshape->src[0] == get_rows
493
+ { 7, 0, 6 }, // sum_rows->src[0] == reshape
494
+ { 8, 0, 7 }, // clamp->src[0] == sum_rows
495
+ { 9, 0, 6 }, // div->src[0] == reshape
496
+ { 9, 1, 8 }, // div->src[1] == clamp
497
+ {10, 0, 9 }, // reshape->src[0] == div
498
+ };
499
+
467
500
  // same as early_softmax_norm but ending after the get_rows
468
501
  static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
469
502
  { 1, 0, 0 }, // reshape->src[0] == softmax
@@ -491,16 +524,10 @@ enum topk_moe_mode {
491
524
  TOPK_MOE_EARLY_SOFTMAX,
492
525
  TOPK_MOE_EARLY_SOFTMAX_NORM,
493
526
  TOPK_MOE_LATE_SOFTMAX,
527
+ TOPK_MOE_SIGMOID_NORM_BIAS,
494
528
  TOPK_MOE_COUNT,
495
529
  };
496
530
 
497
- static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
498
- topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
499
- num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
500
- TOPK_MOE_LATE_SOFTMAX;
501
- return mode;
502
- }
503
-
504
531
  static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
505
532
  { 1, 0, 0 }, // view->src[0] == rope
506
533
  { 2, 0, 1 }, // set_rows->src[0] == view
@@ -651,7 +678,7 @@ struct vk_device_struct {
651
678
  vk_pipeline pipeline_add_id_f32;
652
679
 
653
680
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
654
- vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
681
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
655
682
  vk_pipeline pipeline_scale_f32;
656
683
  vk_pipeline pipeline_sqr_f32;
657
684
  vk_pipeline pipeline_sqrt_f32;
@@ -689,6 +716,7 @@ struct vk_device_struct {
689
716
  vk_pipeline pipeline_gelu_quick[2];
690
717
  vk_pipeline pipeline_silu[2];
691
718
  vk_pipeline pipeline_relu[2];
719
+ vk_pipeline pipeline_xielu[2];
692
720
  vk_pipeline pipeline_neg[2];
693
721
  vk_pipeline pipeline_tanh[2];
694
722
  vk_pipeline pipeline_sigmoid[2];
@@ -730,13 +758,16 @@ struct vk_device_struct {
730
758
 
731
759
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
732
760
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
733
- vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
761
+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
734
762
  vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
735
763
  vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
736
764
  vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
737
765
  vk_pipeline pipeline_topk_f32[num_topk_pipelines];
738
766
  vk_pipeline pipeline_sum_rows_f32;
739
767
  vk_pipeline pipeline_cumsum_f32;
768
+ vk_pipeline pipeline_cumsum_small_f32;
769
+ vk_pipeline pipeline_cumsum_multipass1_f32;
770
+ vk_pipeline pipeline_cumsum_multipass2_f32;
740
771
  vk_pipeline pipeline_argmax_f32;
741
772
  vk_pipeline pipeline_count_equal_i32;
742
773
  std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@@ -762,9 +793,10 @@ struct vk_device_struct {
762
793
  std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
763
794
 
764
795
  vk_pipeline pipeline_flash_attn_split_k_reduce;
796
+ vk_pipeline pipeline_count_experts;
765
797
 
766
798
  // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
767
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
799
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
768
800
 
769
801
  std::vector<vk_pipeline_ref> all_pipelines;
770
802
 
@@ -855,6 +887,15 @@ struct vk_subbuffer {
855
887
  }
856
888
  };
857
889
 
890
+ // vk_event is used for the event-related backend interfaces. It uses 'event' for
891
+ // event_wait and 'fence' for event_synchronize. Polling on an event for
892
+ // event_synchronize wouldn't be sufficient to wait for command buffers to complete,
893
+ // and would lead to validation errors.
894
+ struct vk_event {
895
+ vk::Event event;
896
+ vk::Fence fence;
897
+ };
898
+
858
899
  struct vk_semaphore {
859
900
  vk::Semaphore s;
860
901
  uint64_t value;
@@ -990,6 +1031,16 @@ struct vk_op_push_constants {
990
1031
  uint32_t KY;
991
1032
  float param1;
992
1033
  float param2;
1034
+ float param3;
1035
+ float param4;
1036
+ };
1037
+
1038
+ struct vk_op_count_experts_push_constants {
1039
+ uint32_t ne00;
1040
+ uint32_t ne01;
1041
+ uint32_t nb00;
1042
+ uint32_t nb01;
1043
+ uint32_t a_offset;
993
1044
  };
994
1045
 
995
1046
  struct vk_op_glu_push_constants {
@@ -1160,6 +1211,11 @@ struct vk_op_topk_moe_push_constants {
1160
1211
  uint32_t n_expert_used;
1161
1212
  float clamp_min;
1162
1213
  float clamp_max;
1214
+ uint32_t gating_func;
1215
+ uint32_t has_bias;
1216
+ uint32_t with_norm;
1217
+ float output_scale;
1218
+ float output_bias;
1163
1219
  };
1164
1220
 
1165
1221
  struct vk_op_add_id_push_constants {
@@ -1180,6 +1236,7 @@ struct vk_op_diag_mask_push_constants {
1180
1236
  struct vk_op_rope_push_constants {
1181
1237
  uint32_t rope_mode;
1182
1238
  uint32_t ncols;
1239
+ uint32_t nrows;
1183
1240
  uint32_t n_dims;
1184
1241
  float freq_scale;
1185
1242
  uint32_t p_delta_rows;
@@ -1258,6 +1315,7 @@ struct vk_op_im2col_push_constants {
1258
1315
  int32_t s0; int32_t s1;
1259
1316
  int32_t p0; int32_t p1;
1260
1317
  int32_t d0; int32_t d1;
1318
+ uint32_t batch_IC;
1261
1319
  };
1262
1320
 
1263
1321
  struct vk_op_im2col_3d_push_constants {
@@ -1551,7 +1609,7 @@ class vk_perf_logger {
1551
1609
  total_op_times += time;
1552
1610
  }
1553
1611
  std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
1554
- << " us";
1612
+ << " us = " << (total_op_times / 1000.0) << " us";
1555
1613
 
1556
1614
  // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
1557
1615
  auto it = flops.find(t.first);
@@ -1748,6 +1806,8 @@ struct ggml_backend_vk_context {
1748
1806
  // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
1749
1807
  // If there's no fusion, bit 0 is still set.
1750
1808
  int fused_ops_write_mask {};
1809
+ topk_moe_mode fused_topk_moe_mode {};
1810
+ bool fused_topk_moe_scale {};
1751
1811
 
1752
1812
  // for GGML_VK_PERF_LOGGER
1753
1813
  std::unique_ptr<vk_perf_logger> perf_logger;
@@ -2540,6 +2600,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
2540
2600
  );
2541
2601
  }
2542
2602
 
2603
+ static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
2604
+ VK_LOG_DEBUG("ggml_vk_set_event()");
2605
+
2606
+ ctx->s->buffer.setEvent(
2607
+ event,
2608
+ ctx->p->q->stage_flags
2609
+ );
2610
+ }
2611
+
2543
2612
  static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
2544
2613
  VK_LOG_DEBUG("ggml_vk_wait_events()");
2545
2614
  if (events.empty()) {
@@ -2560,10 +2629,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
2560
2629
  static constexpr uint32_t flash_attention_num_small_rows = 32;
2561
2630
  static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
2562
2631
 
2563
- static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
2632
+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
2564
2633
  if (hsv >= 192) {
2565
2634
  return 2;
2566
- } else if ((hsv | hsk) & 8) {
2635
+ } else if ((hsv | hsk) & 8 || small_cache) {
2567
2636
  return 4;
2568
2637
  } else {
2569
2638
  return 8;
@@ -2585,9 +2654,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
2585
2654
  }
2586
2655
  }
2587
2656
 
2588
- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
2657
+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
2589
2658
  GGML_UNUSED(clamp);
2590
- GGML_UNUSED(hsv);
2591
2659
 
2592
2660
  if (path == FA_SCALAR) {
2593
2661
  if (small_rows) {
@@ -2596,9 +2664,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
2596
2664
  if ((hsv | hsk) & 8) {
2597
2665
  // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
2598
2666
  // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2599
- return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
2667
+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
2600
2668
  } else {
2601
- return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
2669
+ return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
2602
2670
  }
2603
2671
  }
2604
2672
  }
@@ -2627,8 +2695,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
2627
2695
  return {64, 64};
2628
2696
  }
2629
2697
 
2630
- static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
2631
- return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
2698
+ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
2699
+ return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
2632
2700
  }
2633
2701
 
2634
2702
  static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2637,7 +2705,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
2637
2705
  switch (src0_type) {
2638
2706
  case GGML_TYPE_IQ1_S:
2639
2707
  case GGML_TYPE_IQ1_M:
2640
- lut_size = 2*2048;
2708
+ lut_size = 2*2048 + 4*2048;
2641
2709
  break;
2642
2710
  case GGML_TYPE_IQ2_XXS:
2643
2711
  lut_size = 8*256;
@@ -2808,9 +2876,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2808
2876
  s_mmq_wg_denoms_k = { 32, 64, 1 };
2809
2877
 
2810
2878
  // spec constants and tile sizes for quant matmul_id
2811
- l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
2812
- m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2813
- s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2879
+ l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
2880
+ m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
2881
+ s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
2814
2882
  l_mmqid_wg_denoms = { 128, 128, 1 };
2815
2883
  m_mmqid_wg_denoms = { 128, 64, 1 };
2816
2884
  s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2830,39 +2898,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
2830
2898
  const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
2831
2899
  const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
2832
2900
 
2833
- l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2834
- m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2835
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2901
+ const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
2902
+
2903
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2904
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2905
+ s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2836
2906
 
2837
- l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2838
- m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2839
- s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2907
+ l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
2908
+ m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
2909
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
2840
2910
 
2841
2911
  // Integer MMQ has a smaller shared memory profile, but heavier register use
2842
- l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2843
- m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
2844
- s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
2912
+ l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2913
+ m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
2914
+ s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };
2845
2915
 
2846
2916
  // K-quants use even more registers, mitigate by setting WMITER to 1
2847
- l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
2848
- m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
2849
- s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
2917
+ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
2918
+ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
2919
+ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };
2850
2920
 
2851
- l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
2852
- m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
2853
- s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
2921
+ l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
2922
+ m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
2923
+ s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
2854
2924
 
2855
- l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
2856
- m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
2857
- s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
2925
+ l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
2926
+ m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
2927
+ s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
2858
2928
 
2859
- l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
2860
- m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
2861
- s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
2929
+ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
2930
+ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
2931
+ s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
2862
2932
 
2863
- l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
2864
- m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
2865
- s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
2933
+ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
2934
+ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
2935
+ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
2866
2936
 
2867
2937
  // chip specific tuning
2868
2938
  if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
@@ -2970,11 +3040,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2970
3040
  align, disable_robustness, require_full_subgroups, required_subgroup_size);
2971
3041
  };
2972
3042
 
2973
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2974
- return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
3043
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
3044
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
2975
3045
  };
2976
3046
 
2977
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
3047
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
2978
3048
  // For large number of rows, 128 invocations seems to work best.
2979
3049
  // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
2980
3050
  // can't use 256 for D==80.
@@ -2984,7 +3054,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2984
3054
  uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
2985
3055
  ? scalar_flash_attention_workgroup_size
2986
3056
  : ((small_rows && (D % 32) == 0) ? 256 : 128);
2987
- auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
3057
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
2988
3058
 
2989
3059
  // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
2990
3060
  // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2999,21 +3069,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
2999
3069
  uint32_t HSK = fa.first.HSK; \
3000
3070
  uint32_t HSV = fa.first.HSV; \
3001
3071
  bool small_rows = fa.first.small_rows; \
3072
+ bool small_cache = fa.first.small_cache; \
3002
3073
  FaCodePath path = fa.first.path; \
3003
3074
  bool aligned = fa.first.aligned; \
3004
3075
  bool f32acc = fa.first.f32acc; \
3005
3076
  if (path == FAPATH) { \
3006
3077
  if (aligned) { \
3007
3078
  if (f32acc) { \
3008
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3079
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3009
3080
  } else { \
3010
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3081
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3011
3082
  } \
3012
3083
  } else { \
3013
3084
  if (f32acc) { \
3014
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3085
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3015
3086
  } else { \
3016
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3087
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3017
3088
  } \
3018
3089
  } \
3019
3090
  } \
@@ -3045,17 +3116,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
3045
3116
  #endif
3046
3117
  #undef CREATE_FA
3047
3118
 
3119
+ const int mul_mat_id_param_count = 5;
3120
+
3048
3121
  #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3049
3122
  if (device->coopmat2) {
3050
3123
 
3051
3124
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
3052
3125
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
3053
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
3054
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
3055
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
3056
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
3057
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
3058
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
3126
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
3127
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
3128
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
3129
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
3130
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
3131
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
3059
3132
 
3060
3133
  // Create 2 variants, {f16,f32} accumulator
3061
3134
  #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -3091,32 +3164,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
3091
3164
 
3092
3165
  GGML_ASSERT(device->subgroup_ballot);
3093
3166
 
3094
- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
3167
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
3095
3168
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3096
3169
  if (device->coopmat_bf16_support) {
3097
- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
3170
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
3098
3171
  }
3099
3172
  #endif
3100
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3101
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3102
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3103
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3104
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3105
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3106
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3107
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3108
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3109
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3110
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3111
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3112
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3113
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3114
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3115
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3116
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3117
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3118
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3119
- CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
3173
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3174
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3175
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3176
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3177
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3178
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3179
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3180
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3181
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3182
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3183
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3184
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3185
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3186
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3187
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3188
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3189
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3190
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3191
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3192
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3120
3193
  #undef CREATE_MM
3121
3194
  #undef CREATE_MM2
3122
3195
  } else
@@ -3205,35 +3278,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
3205
3278
 
3206
3279
  GGML_ASSERT(device->subgroup_ballot);
3207
3280
 
3208
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3209
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3210
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3281
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
3282
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
3283
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
3211
3284
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3212
3285
  if (device->coopmat_bf16_support) {
3213
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
3286
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
3214
3287
  }
3215
3288
  #endif
3216
3289
 
3217
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3218
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3219
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3220
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3221
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3222
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3223
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3224
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3225
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3226
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3227
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3228
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3229
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3230
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3231
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3232
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3233
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3234
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3235
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3236
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
3290
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3291
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3292
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3293
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3294
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3295
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3296
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3297
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3298
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3299
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3300
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3301
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3302
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3303
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3304
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3305
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3306
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3307
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3308
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3309
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3237
3310
  #undef CREATE_MM2
3238
3311
  #undef CREATE_MM
3239
3312
  } else
@@ -3318,91 +3391,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
3318
3391
  #endif
3319
3392
 
3320
3393
  if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3321
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3322
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3323
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3324
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3325
-
3326
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3327
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3328
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3329
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3330
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3331
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3332
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3333
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3334
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3335
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3336
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3337
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3338
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3339
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3340
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3341
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3342
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3343
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3344
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3345
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3394
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3395
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3396
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3397
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3398
+
3399
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3400
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3401
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3402
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3403
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3404
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3405
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3406
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3407
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3408
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3409
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3410
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3411
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3412
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3413
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3414
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3415
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3416
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3417
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3418
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3346
3419
 
3347
3420
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3348
3421
  if (device->integer_dot_product) {
3349
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3350
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3351
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3352
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3353
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3422
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3423
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3424
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3425
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3426
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3354
3427
 
3355
- CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3428
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3356
3429
 
3357
- CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3358
- CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3359
- CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3360
- CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3361
- CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3430
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3431
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3432
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3433
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3434
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3362
3435
  }
3363
3436
  #endif
3364
3437
  } else {
3365
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3366
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3367
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3368
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3369
-
3370
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3371
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3372
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3373
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3374
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3375
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3376
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3377
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3378
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3379
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3380
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3381
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3382
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3383
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3384
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3385
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3386
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3387
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3388
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3389
- CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3438
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3439
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3440
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3441
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3442
+
3443
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3444
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3445
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3446
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3447
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3448
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3449
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3450
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3451
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3452
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3453
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3454
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3455
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3456
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3457
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3458
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3459
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3460
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3461
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3462
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3390
3463
 
3391
3464
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3392
3465
  if (device->integer_dot_product) {
3393
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3394
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3395
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3396
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3397
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3466
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3467
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3468
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3469
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3470
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3398
3471
 
3399
- CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
3472
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3400
3473
 
3401
- CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3402
- CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3403
- CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3404
- CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3405
- CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
3474
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3475
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3476
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3477
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3478
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3406
3479
  }
3407
3480
  #endif
3408
3481
  }
@@ -3479,57 +3552,57 @@ static void ggml_vk_load_shaders(vk_device& device) {
3479
3552
  #endif
3480
3553
 
3481
3554
  if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3482
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3483
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3484
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
3485
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
3486
-
3487
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3488
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3489
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3490
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3491
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3492
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3493
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3494
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3495
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3496
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3497
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3498
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3499
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3500
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3501
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3502
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3503
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3504
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3505
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3506
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
3555
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3556
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3557
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3558
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3559
+
3560
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3561
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3562
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3563
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3564
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3565
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3566
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3567
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3568
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3569
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3570
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3571
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3572
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3573
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3574
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3575
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3576
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3577
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3578
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3579
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3507
3580
  } else {
3508
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3509
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3510
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
3511
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3512
-
3513
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3514
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3515
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3516
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3517
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3518
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3519
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3520
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3521
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3522
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3523
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3524
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3525
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3526
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3527
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3528
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3529
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3530
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3531
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3532
- CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
3581
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3582
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3583
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
3584
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3585
+
3586
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3587
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3588
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3589
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3590
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3591
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3592
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3593
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3594
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3595
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3596
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3597
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3598
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3599
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3600
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3601
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3602
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3603
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3604
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3605
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3533
3606
  }
3534
3607
  }
3535
3608
  // reusing CREATE_MM from the fp32 path
@@ -3548,7 +3621,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3548
3621
  s_wg_denoms = { 32, 32, 1 };
3549
3622
 
3550
3623
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3551
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
3624
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3552
3625
  }
3553
3626
  #undef CREATE_MM
3554
3627
 
@@ -3559,6 +3632,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3559
3632
  uint32_t rm_kq = 2;
3560
3633
  uint32_t rm_stdq_int = 1;
3561
3634
  uint32_t rm_kq_int = 1;
3635
+ auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
3562
3636
  if (device->vendor_id == VK_VENDOR_ID_AMD) {
3563
3637
  if (device->architecture == AMD_GCN) {
3564
3638
  rm_stdq = 2;
@@ -3662,6 +3736,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
3662
3736
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3663
3737
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3664
3738
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3739
+
3740
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3741
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3742
+
3665
3743
  }
3666
3744
  #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3667
3745
  }
@@ -3708,6 +3786,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
3708
3786
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3709
3787
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3710
3788
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3789
+
3790
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
3791
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
3711
3792
  }
3712
3793
  #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3713
3794
  }
@@ -3715,6 +3796,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3715
3796
  #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3716
3797
  GGML_UNUSED(rm_stdq_int);
3717
3798
  GGML_UNUSED(rm_kq_int);
3799
+ GGML_UNUSED(rm_iq_int);
3718
3800
  #endif
3719
3801
 
3720
3802
  // dequant shaders
@@ -3933,6 +4015,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3933
4015
  ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
3934
4016
  ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
3935
4017
  ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
4018
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
3936
4019
 
3937
4020
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3938
4021
 
@@ -3973,6 +4056,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3973
4056
  CREATE_UNARY(gelu_quick)
3974
4057
  CREATE_UNARY(silu)
3975
4058
  CREATE_UNARY(relu)
4059
+ CREATE_UNARY(xielu)
3976
4060
  CREATE_UNARY(neg)
3977
4061
  CREATE_UNARY(tanh)
3978
4062
  CREATE_UNARY(sigmoid)
@@ -4054,6 +4138,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4054
4138
 
4055
4139
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4056
4140
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4141
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4057
4142
  } else {
4058
4143
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4059
4144
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4062,6 +4147,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4062
4147
 
4063
4148
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4064
4149
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4150
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4065
4151
  }
4066
4152
 
4067
4153
  for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -4097,10 +4183,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
4097
4183
 
4098
4184
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4099
4185
 
4100
- ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
4186
+ const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
4187
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
4188
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
4189
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
4190
+ ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
4101
4191
 
4102
4192
  ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
4103
4193
 
4194
+ ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
4195
+
4104
4196
  for (auto &s : device->pipeline_solve_tri_f32) {
4105
4197
  const vk_solve_tri_pipeline_state &state = s.first;
4106
4198
 
@@ -4251,9 +4343,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4251
4343
 
4252
4344
  for (uint32_t use_push = 0; use_push < 2; ++use_push) {
4253
4345
  for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4254
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
4255
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
4256
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
4346
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
4257
4347
  }
4258
4348
  }
4259
4349
 
@@ -5544,6 +5634,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
5544
5634
  case GGML_TYPE_Q4_K:
5545
5635
  case GGML_TYPE_Q5_K:
5546
5636
  case GGML_TYPE_Q6_K:
5637
+ case GGML_TYPE_IQ1_S:
5638
+ case GGML_TYPE_IQ1_M:
5547
5639
  break;
5548
5640
  default:
5549
5641
  return nullptr;
@@ -5700,6 +5792,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
5700
5792
  case GGML_TYPE_Q4_K:
5701
5793
  case GGML_TYPE_Q5_K:
5702
5794
  case GGML_TYPE_Q6_K:
5795
+ case GGML_TYPE_IQ1_S:
5796
+ case GGML_TYPE_IQ1_M:
5703
5797
  break;
5704
5798
  default:
5705
5799
  return nullptr;
@@ -5898,6 +5992,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
5898
5992
  std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
5899
5993
  }
5900
5994
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
5995
+ GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
5996
+ wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
5997
+ wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
5901
5998
  GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
5902
5999
  GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
5903
6000
  GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@@ -6081,13 +6178,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
6081
6178
  }
6082
6179
  }
6083
6180
 
6084
- static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
6181
+ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
6085
6182
  VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
6086
- // Buffer is already mapped
6087
- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
6088
- std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
6089
- GGML_ABORT("fatal error");
6090
- }
6091
6183
  // Check if src is pinned memory
6092
6184
  vk_buffer buf = nullptr;
6093
6185
  size_t buf_offset = 0;
@@ -6112,12 +6204,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6112
6204
 
6113
6205
  ggml_vk_sync_buffers(nullptr, subctx);
6114
6206
  subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6115
- return;
6207
+ return true;
6116
6208
  }
6117
6209
  VK_LOG_DEBUG("STAGING");
6118
6210
 
6119
6211
  if (!sync_staging) {
6120
- GGML_ABORT("Asynchronous write to non-pinned memory not supported");
6212
+ // copy was not handled caller needs to fall back
6213
+ return false;
6121
6214
  }
6122
6215
 
6123
6216
  // Staging buffer required
@@ -6141,9 +6234,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6141
6234
  deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
6142
6235
  }
6143
6236
  }
6237
+ return true;
6144
6238
  }
6145
6239
 
// ggml_vk_buffer_write_async: 1D convenience wrapper around ggml_vk_buffer_write_2d_async.
// It passes `size` as both the source pitch and the row width, and 1 as the height,
// then returns whatever the 2D variant returns.
// This hunk only changes the return type from void ('-' line) to bool ('+' line).
6146
- static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
6240
+ static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
6147
6241
  VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
6148
6242
  return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
6149
6243
  }
@@ -6162,7 +6256,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
6162
6256
 
6163
6257
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
6164
6258
  ggml_vk_ctx_begin(dst->device, subctx);
6165
- ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
6259
+ bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
6260
+ GGML_ASSERT(ret);
6166
6261
  ggml_vk_ctx_end(subctx);
6167
6262
 
6168
6263
  for (auto& cpy : subctx->in_memcpys) {
@@ -6497,18 +6592,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context *
6497
6592
 
// ggml_vk_matmul_id: logs its arguments, packs the push constants and calls
// ggml_vk_dispatch_pipeline for `pipeline`.
// The '+' lines add the `expert_count_buf` parameter, print it in the debug log and
// append it as a fifth buffer after `ids` in the dispatch's buffer list.
// The push-constant contents are the same on both sides of the diff.
6498
6593
  static void ggml_vk_matmul_id(
6499
6594
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
6500
- vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
6595
+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
6501
6596
  uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
6502
6597
  uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
6503
6598
  uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
6504
6599
  uint32_t padded_n) {
6505
- VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
6600
+ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
6506
6601
  "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
6507
6602
  "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
6508
6603
  "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
6509
6604
  const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
6510
6605
  nei0, nei1, nbi1, ne11, padded_n };
// n_as is not part of the push constants; it is only used as the third dispatch element count.
6511
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
6606
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
6512
6607
  }
6513
6608
 
6514
6609
  static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -6680,7 +6775,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
6680
6775
 
6681
6776
  vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
6682
6777
 
6683
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
6778
+ const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
6779
+ // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
6780
+ const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
6781
+ const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
6782
+
6783
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
6684
6784
  ggml_vk_sync_buffers(ctx, subctx);
6685
6785
  }
6686
6786
 
@@ -6964,7 +7064,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
6964
7064
  // Quantization overhead is not worth it for small k
6965
7065
  switch (device->vendor_id) {
6966
7066
  case VK_VENDOR_ID_NVIDIA:
6967
- if (src0_type == GGML_TYPE_Q2_K) {
7067
+ if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
6968
7068
  return true;
6969
7069
  }
6970
7070
 
@@ -7491,6 +7591,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7491
7591
  const uint64_t nei0 = ids->ne[0];
7492
7592
  const uint64_t nei1 = ids->ne[1];
7493
7593
 
7594
+ const uint32_t nbi0 = ids->nb[0];
7494
7595
  const uint32_t nbi1 = ids->nb[1];
7495
7596
  const uint32_t nbi2 = ids->nb[2];
7496
7597
 
@@ -7598,6 +7699,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7598
7699
  if (quantize_y) {
7599
7700
  to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
7600
7701
  }
7702
+ vk_pipeline count_experts = ctx->device->pipeline_count_experts;
7703
+
7704
+ uint32_t expert_count_size = sizeof(uint32_t) * n_as;
7601
7705
 
7602
7706
  {
7603
7707
  if (
@@ -7613,6 +7717,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7613
7717
  ctx->prealloc_size_y = y_sz;
7614
7718
  ggml_vk_preallocate_buffers(ctx, subctx);
7615
7719
  }
7720
+ if (ctx->prealloc_size_split_k < expert_count_size) {
7721
+ ctx->prealloc_size_split_k = expert_count_size;
7722
+ ggml_vk_preallocate_buffers(ctx, subctx);
7723
+ }
7616
7724
 
7617
7725
  // Request descriptor sets
7618
7726
  ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
@@ -7625,6 +7733,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7625
7733
  if (quantize_y) {
7626
7734
  ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7627
7735
  }
7736
+ ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
7628
7737
  }
7629
7738
 
7630
7739
  vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -7674,6 +7783,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7674
7783
  ggml_vk_sync_buffers(ctx, subctx);
7675
7784
  }
7676
7785
  }
7786
+ // Count how many times each expert is used
7787
+ vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
7788
+ if (ctx->prealloc_split_k_need_sync) {
7789
+ ggml_vk_sync_buffers(ctx, subctx);
7790
+ }
7791
+ {
7792
+ const std::vector<uint32_t> pc = { (uint32_t)nei0,
7793
+ (uint32_t)nei1,
7794
+ (uint32_t)(nbi0 / ggml_type_size(ids->type)),
7795
+ (uint32_t)(nbi1 / ggml_type_size(ids->type)),
7796
+ (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
7797
+ ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
7798
+ { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
7799
+ }
7677
7800
 
7678
7801
  if (x_non_contig) {
7679
7802
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
@@ -7681,7 +7804,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7681
7804
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
7682
7805
  ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
7683
7806
  { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
7684
- ggml_vk_sync_buffers(ctx, subctx);
7685
7807
  }
7686
7808
  if (y_non_contig) {
7687
7809
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@@ -7705,6 +7827,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7705
7827
  ctx->prealloc_y_last_tensor_used = src1;
7706
7828
  }
7707
7829
  }
7830
+ ggml_vk_sync_buffers(ctx, subctx);
7708
7831
 
7709
7832
  uint32_t stride_batch_x = ne00*ne01;
7710
7833
  uint32_t stride_batch_y = ne10*ne11;
@@ -7721,7 +7844,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7721
7844
  ggml_vk_matmul_id(
7722
7845
  ctx, subctx, pipeline,
7723
7846
  { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
7724
- { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz },
7847
+ { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
7725
7848
  ne01, ne21, ne10, ne10, ne10, ne01,
7726
7849
  stride_batch_x, stride_batch_y, ne20*ne21,
7727
7850
  n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
@@ -7733,6 +7856,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
7733
7856
  if (y_non_contig || quantize_y) {
7734
7857
  ctx->prealloc_y_need_sync = true;
7735
7858
  }
7859
+ ctx->prealloc_split_k_need_sync = true;
7736
7860
  }
7737
7861
 
7738
7862
  static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7982,11 +8106,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
7982
8106
  }
7983
8107
  }
7984
8108
 
7985
- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
8109
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
7986
8110
  // Needs to be kept up to date on shader changes
7987
8111
  GGML_UNUSED(hsv);
7988
8112
  const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7989
- const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
8113
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
7990
8114
  const uint32_t Bc = scalar_flash_attention_Bc;
7991
8115
 
7992
8116
  const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8110,6 +8234,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8110
8234
  uint32_t workgroups_y = (uint32_t)neq2;
8111
8235
  uint32_t workgroups_z = (uint32_t)neq3;
8112
8236
 
8237
+ const bool small_cache = nek1 < 1024;
8238
+
8113
8239
  // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
8114
8240
  // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
8115
8241
  uint32_t max_gqa;
@@ -8117,7 +8243,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8117
8243
  case FA_SCALAR:
8118
8244
  case FA_COOPMAT1:
8119
8245
  // We may switch from coopmat1 to scalar, so use the scalar limit for both
8120
- max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
8246
+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
8121
8247
  break;
8122
8248
  case FA_COOPMAT2:
8123
8249
  max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -8151,7 +8277,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8151
8277
 
8152
8278
  // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
8153
8279
  if (path == FA_SCALAR &&
8154
- !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
8280
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
8155
8281
  small_rows = true;
8156
8282
  }
8157
8283
 
@@ -8167,7 +8293,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8167
8293
  v_stride /= 4;
8168
8294
  }
8169
8295
 
8170
- uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
8296
+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
8171
8297
  bool aligned = (KV % alignment) == 0 &&
8172
8298
  // the "aligned" shader variant will forcibly align strides, for performance
8173
8299
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@@ -8179,7 +8305,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8179
8305
 
8180
8306
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
8181
8307
 
8182
- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
8308
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
8183
8309
 
8184
8310
  vk_pipeline pipeline = nullptr;
8185
8311
 
@@ -8404,7 +8530,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8404
8530
  return nullptr;
8405
8531
  case GGML_OP_UPSCALE:
8406
8532
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8407
- ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
8533
+ uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
8408
8534
  switch (mode) {
8409
8535
  case GGML_SCALE_MODE_NEAREST:
8410
8536
  return ctx->device->pipeline_upscale_nearest_f32;
@@ -8412,6 +8538,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8412
8538
  return ctx->device->pipeline_upscale_bilinear_f32;
8413
8539
  case GGML_SCALE_MODE_BICUBIC:
8414
8540
  return ctx->device->pipeline_upscale_bicubic_f32;
8541
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
8542
+ return ctx->device->pipeline_upscale_bilinear_antialias_f32;
8415
8543
  default:
8416
8544
  return nullptr;
8417
8545
  }
@@ -8549,6 +8677,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8549
8677
  return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
8550
8678
  case GGML_UNARY_OP_RELU:
8551
8679
  return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
8680
+ case GGML_UNARY_OP_XIELU:
8681
+ return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
8552
8682
  case GGML_UNARY_OP_NEG:
8553
8683
  return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
8554
8684
  case GGML_UNARY_OP_TANH:
@@ -8613,10 +8743,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8613
8743
  if (ctx->num_additional_fused_ops) {
8614
8744
  uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
8615
8745
  GGML_ASSERT(idx < num_topk_moe_pipelines);
8616
- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
8617
8746
  // use n_experts from push constant if it's not equal to the power of two spec constant
8618
8747
  bool use_push = dst->ne[0] != (1u << idx);
8619
- return ctx->device->pipeline_topk_moe[idx][mode][use_push];
8748
+ return ctx->device->pipeline_topk_moe[idx][use_push];
8620
8749
  }
8621
8750
 
8622
8751
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8654,6 +8783,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8654
8783
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8655
8784
  return ctx->device->pipeline_rope_multi_f32;
8656
8785
  }
8786
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
8787
+ return ctx->device->pipeline_rope_multi_f32_f16;
8788
+ }
8657
8789
  if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8658
8790
  return ctx->device->pipeline_rope_multi_f16;
8659
8791
  }
@@ -8686,7 +8818,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8686
8818
  return nullptr;
8687
8819
  case GGML_OP_CUMSUM:
8688
8820
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8689
- return ctx->device->pipeline_cumsum_f32;
8821
+ if (src0->ne[0] <= 512) {
8822
+ return ctx->device->pipeline_cumsum_small_f32;
8823
+ } else {
8824
+ return ctx->device->pipeline_cumsum_f32;
8825
+ }
8690
8826
  }
8691
8827
  return nullptr;
8692
8828
  case GGML_OP_SOLVE_TRI:
@@ -9057,10 +9193,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
9057
9193
  elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
9058
9194
  } break;
9059
9195
  case GGML_OP_DIAG_MASK_INF:
9060
- case GGML_OP_ROPE:
9061
- case GGML_OP_ROPE_BACK:
9062
9196
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
9063
9197
  break;
9198
+ case GGML_OP_ROPE:
9199
+ case GGML_OP_ROPE_BACK:
9200
+ {
9201
+ uint32_t nrows = (uint32_t)ggml_nrows(src0);
9202
+ uint32_t z = 1;
9203
+ if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
9204
+ z = CEIL_DIV(nrows, 32768);
9205
+ nrows = 32768;
9206
+ }
9207
+ elements = { nrows, (uint32_t)ne00, z };
9208
+
9209
+ } break;
9064
9210
  case GGML_OP_GET_ROWS:
9065
9211
  elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
9066
9212
  elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -9084,6 +9230,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
9084
9230
  const uint32_t batch = src1->ne[is_2D ? 3 : 2];
9085
9231
 
9086
9232
  elements = { OW * KW * KH, OH, batch * IC };
9233
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9234
+ elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9087
9235
  } break;
9088
9236
  case GGML_OP_IM2COL_3D:
9089
9237
  {
@@ -9695,14 +9843,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
9695
9843
 
9696
9844
  ggml_vk_op_f32_opt_step_adamw(
9697
9845
  ctx, subctx, dst,
9698
- { (uint32_t)n, 0, 0.0f, 0.0f }
9846
+ { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
9699
9847
  );
9700
9848
  }
9701
9849
 
// ggml_vk_opt_step_sgd: forwards GGML_OP_OPT_STEP_SGD to ggml_vk_op_f32 with three sources.
// The first push-constant value is the element count of dst->src[0], cast to uint32_t.
// The '+' line only appends two extra 0.0f values to the brace initializer.
// NOTE(review): presumably vk_op_push_constants gained two float members; its
// definition is not visible in this chunk, so confirm there.
9702
9850
  static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
9703
9851
  const size_t n = ggml_nelements(dst->src[0]);
9704
9852
 
9705
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
9853
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
9706
9854
  }
9707
9855
 
9708
9856
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9788,6 +9936,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
9788
9936
  1,
9789
9937
  ggml_get_op_params_f32(dst, 0),
9790
9938
  ggml_get_op_params_f32(dst, 2),
9939
+ 0.0f, 0.0f,
9791
9940
  };
9792
9941
 
9793
9942
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9809,6 +9958,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
9809
9958
  1,
9810
9959
  ggml_get_op_params_f32(dst, 0),
9811
9960
  0.0f,
9961
+ 0.0f, 0.0f,
9812
9962
  };
9813
9963
 
9814
9964
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9924,13 +10074,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
9924
10074
  }
9925
10075
 
// ggml_vk_silu_back: forwards GGML_OP_SILU_BACK to ggml_vk_op_f32 with src0 and src1.
// The first push-constant value is the element count of src0; the rest are zero.
// The '+' line only appends two extra 0.0f values to the brace initializer.
// NOTE(review): presumably matches two new float members in vk_op_push_constants — confirm.
9926
10076
  static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9927
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10077
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
9928
10078
  }
9929
10079
 
// ggml_vk_norm: forwards GGML_OP_NORM to ggml_vk_op_f32 with src0 only.
// Push constants carry src0->ne[0], src0->ne[1] and dst->op_params[0] read as a float.
// NOTE(review): op_params[0] is presumably the epsilon — confirm against the op definition.
// The '+' line only appends two extra 0.0f values to the brace initializer.
9930
10080
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9931
10081
  float * op_params = (float *)dst->op_params;
9932
10082
 
9933
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10083
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
9934
10084
  }
9935
10085
 
9936
10086
  static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9941,7 +10091,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
9941
10091
  const float eps = float_op_params[1];
9942
10092
  const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
9943
10093
 
9944
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
10094
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
9945
10095
  }
9946
10096
 
9947
10097
  static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -9984,7 +10134,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
9984
10134
  uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
9985
10135
 
9986
10136
  vk_op_rope_push_constants rope {
9987
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
10137
+ (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
9988
10138
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
9989
10139
  has_ff, (uint32_t)src0->ne[2], nb01, nb02,
9990
10140
  { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
@@ -10110,16 +10260,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
10110
10260
 
// ggml_vk_rms_norm_back: forwards GGML_OP_RMS_NORM_BACK to ggml_vk_op_f32 with src0 and src1.
// Push constants carry src0->ne[0], src0->ne[1] and dst->op_params[0] read as a float.
// The '+' line only appends two extra 0.0f values to the brace initializer.
10111
10261
  static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10112
10262
  float * op_params = (float *)dst->op_params;
10113
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10263
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10114
10264
  }
10115
10265
 
// ggml_vk_l2_norm: forwards GGML_OP_L2_NORM to ggml_vk_op_f32 with src0 only.
// Push constants carry src0->ne[0], src0->ne[1] and dst->op_params[0] read as a float.
// The '+' line only appends two extra 0.0f values to the brace initializer.
10116
10266
  static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10117
10267
  float * op_params = (float *)dst->op_params;
10118
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
10268
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10119
10269
  }
10120
10270
 
// ggml_vk_unary: forwards GGML_OP_UNARY to ggml_vk_op_f32 with src0 only.
// The first push-constant value is the element count of src0; the rest are zero.
// The '+' line appends two extra 0.0f values to the brace initializer.
10121
10271
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10122
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10272
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10273
+ }
10274
+
// ggml_vk_xielu (added by this diff): makes the same GGML_OP_UNARY call as ggml_vk_unary,
// but fills the four float push-constant slots with dst->op_params[1..4] read as floats.
// NOTE(review): presumably op_params[0] holds the unary op id and [1..4] the xIELU
// parameters; that layout is not shown in this chunk — confirm.
// The final context-line brace (old line 10123) now closes this new function.
10275
+ static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10276
+ float * op_params = (float *)dst->op_params;
10277
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
10278
+ {
10279
+ (uint32_t)ggml_nelements(src0), 0,
10280
+ op_params[1], op_params[2], op_params[3], op_params[4]
10281
+ }
10282
+ );
10123
10283
  }
10124
10284
 
10125
10285
  static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10244,18 +10404,20 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
10244
10404
 
10245
10405
  static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10246
10406
  float * op_params = (float *)dst->op_params;
10247
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
10407
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
10248
10408
  }
10249
10409
 
10250
10410
  static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
10251
- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
10411
+ topk_moe_mode mode = ctx->fused_topk_moe_mode;
10252
10412
  ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
10253
- ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
10254
- (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
10255
- cgraph->nodes[node_idx + 5];
10256
- ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
10413
+ ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
10414
+ ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
10415
+ ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
10416
+ (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
10417
+ cgraph->nodes[node_idx + 3];
10257
10418
 
10258
10419
  GGML_ASSERT(logits->type == GGML_TYPE_F32);
10420
+ GGML_ASSERT(bias->type == GGML_TYPE_F32);
10259
10421
  GGML_ASSERT(weights->type == GGML_TYPE_F32);
10260
10422
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
10261
10423
 
@@ -10270,6 +10432,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
10270
10432
  ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10271
10433
 
10272
10434
  vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
10435
+ vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
10273
10436
  vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
10274
10437
  vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
10275
10438
 
@@ -10277,18 +10440,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
10277
10440
  pc.n_rows = n_rows;
10278
10441
  pc.n_experts_push = n_experts;
10279
10442
  pc.n_expert_used = n_expert_used;
10443
+ pc.clamp_min = -std::numeric_limits<float>::infinity();
10444
+ pc.clamp_max = std::numeric_limits<float>::infinity();
10280
10445
  if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
10281
10446
  ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
10447
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10448
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10449
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10450
+ }
10451
+ if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
10452
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
10453
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10282
10454
  pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10283
10455
  pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10284
10456
  }
10285
10457
 
10458
+ #define GATING_FUNC_SOFTMAX 0
10459
+ #define GATING_FUNC_SIGMOID 1
10460
+ #define GATING_FUNC_SOFTMAX_WEIGHT 2
10461
+
10462
+ pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
10463
+ mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
10464
+ GATING_FUNC_SOFTMAX;
10465
+ pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10466
+ pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10467
+ if (ctx->fused_topk_moe_scale) {
10468
+ GGML_ASSERT(weights->op == GGML_OP_SCALE);
10469
+ pc.output_scale = ggml_get_op_params_f32(weights, 0);
10470
+ pc.output_bias = ggml_get_op_params_f32(weights, 1);
10471
+ } else {
10472
+ pc.output_scale = 1.0f;
10473
+ pc.output_bias = 0.0f;
10474
+ }
10475
+
10286
10476
  GGML_ASSERT(n_expert_used <= n_experts);
10287
10477
 
10288
10478
  const uint32_t rows_per_block = 4;
10289
10479
  std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
10290
10480
 
10291
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
10481
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
10292
10482
  }
10293
10483
 
10294
10484
  static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@@ -10536,16 +10726,58 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
10536
10726
  }
10537
10727
 
10538
10728
  static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10539
- vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
10540
- ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
10729
+ vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
10730
+ // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
10731
+ // For fewer, larger rows, use the multipass shader to spread each row across SMs.
10732
+ if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
10733
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
10734
+ return;
10735
+ }
10736
+
10737
+ // First pass computes partial sums within a block, and stores the last partial
10738
+ // to the temp buffer. Second pass sums the block partials from the temp buffer
10739
+ // and adds that to the result of the first pass.
10740
+ vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
10741
+ vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
10742
+ GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
10743
+
10744
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
10745
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
10746
+
10747
+ std::array<uint32_t, 3> elements;
10748
+
10749
+ elements[0] = dst->ne[0];
10750
+ elements[1] = (uint32_t)ggml_nrows(dst);
10751
+ elements[2] = 1;
10752
+
10753
+ size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
10754
+
10755
+ if (ctx->prealloc_size_split_k < temp_size) {
10756
+ ctx->prealloc_size_split_k = temp_size;
10757
+ ggml_vk_preallocate_buffers(ctx, subctx);
10758
+ }
10759
+
10760
+ vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
10761
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10762
+ vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
10763
+
10764
+ if (ctx->prealloc_split_k_need_sync) {
10765
+ ggml_vk_sync_buffers(ctx, subctx);
10766
+ }
10767
+
10768
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
10769
+ ggml_vk_sync_buffers(ctx, subctx);
10770
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
10771
+
10772
+ ctx->prealloc_split_k_need_sync = true;
10541
10773
  }
10542
10774
 
10543
10775
  static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10544
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
10776
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
10545
10777
  }
10546
10778
 
10547
10779
  static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10548
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
10780
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10549
10781
  }
10550
10782
 
10551
10783
  static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10587,6 +10819,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
10587
10819
  const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
10588
10820
 
10589
10821
  const uint32_t pelements = OW * KW * KH;
10822
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
10590
10823
 
10591
10824
  const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
10592
10825
  const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10599,7 +10832,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
10599
10832
  IC, IW, IH, OW, OH, KW, KH,
10600
10833
  pelements,
10601
10834
  IC * KH * KW,
10602
- s0, s1, p0, p1, d0, d1,
10835
+ s0, s1, p0, p1, d0, d1, batch * IC
10603
10836
  });
10604
10837
  }
10605
10838
 
@@ -10804,7 +11037,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
10804
11037
 
10805
11038
  static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10806
11039
  const float * op_params = (const float *)dst->op_params;
10807
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
11040
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
10808
11041
  }
10809
11042
 
10810
11043
  #ifdef GGML_VULKAN_RUN_TESTS
@@ -12029,6 +12262,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12029
12262
 
12030
12263
  break;
12031
12264
  case GGML_OP_UNARY:
12265
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12266
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12267
+ break;
12268
+ }
12269
+
12032
12270
  switch (ggml_get_unary_op(node)) {
12033
12271
  case GGML_UNARY_OP_EXP:
12034
12272
  case GGML_UNARY_OP_SILU:
@@ -12050,6 +12288,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12050
12288
  case GGML_UNARY_OP_TRUNC:
12051
12289
  ggml_vk_unary(ctx, compute_ctx, src0, node);
12052
12290
  break;
12291
+ case GGML_UNARY_OP_XIELU:
12292
+ ggml_vk_xielu(ctx, compute_ctx, src0, node);
12293
+ break;
12053
12294
  default:
12054
12295
  return false;
12055
12296
  }
@@ -12073,7 +12314,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12073
12314
 
12074
12315
  break;
12075
12316
  case GGML_OP_SOFT_MAX:
12076
- if (ctx->num_additional_fused_ops) {
12317
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12077
12318
  ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12078
12319
  } else {
12079
12320
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12093,7 +12334,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12093
12334
 
12094
12335
  break;
12095
12336
  case GGML_OP_ARGSORT:
12096
- if (ctx->num_additional_fused_ops) {
12337
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12097
12338
  ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12098
12339
  } else {
12099
12340
  ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -12643,7 +12884,23 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
12643
12884
 
12644
12885
  vk_buffer buf = buf_ctx->dev_buffer;
12645
12886
 
12646
- ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12887
+ auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
12888
+
12889
+ bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
12890
+
12891
+ if (!ret) {
12892
+ ggml_vk_ensure_sync_staging_buffer(ctx, size);
12893
+ ggml_vk_sync_buffers(nullptr, transfer_ctx);
12894
+
12895
+ vk::BufferCopy buffer_cpy;
12896
+ buffer_cpy.srcOffset = 0;
12897
+ buffer_cpy.dstOffset = dst_offset;
12898
+ buffer_cpy.size = size;
12899
+
12900
+ transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
12901
+ deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
12902
+ ggml_vk_synchronize(ctx);
12903
+ }
12647
12904
  }
12648
12905
 
12649
12906
  static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -12920,42 +13177,81 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
12920
13177
 
12921
13178
  const ggml_tensor * softmax;
12922
13179
  const ggml_tensor * weights;
13180
+ const ggml_tensor * get_rows;
13181
+ const ggml_tensor * argsort;
12923
13182
 
12924
13183
  switch (mode) {
12925
13184
  case TOPK_MOE_EARLY_SOFTMAX_NORM:
12926
13185
  softmax = cgraph->nodes[node_idx + 0];
12927
13186
  weights = cgraph->nodes[node_idx + 9];
13187
+ get_rows = cgraph->nodes[node_idx + 4];
13188
+ argsort = cgraph->nodes[node_idx + 2];
13189
+ break;
13190
+ case TOPK_MOE_SIGMOID_NORM_BIAS:
13191
+ softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
13192
+ weights = cgraph->nodes[node_idx + 10];
13193
+ get_rows = cgraph->nodes[node_idx + 5];
13194
+ argsort = cgraph->nodes[node_idx + 3];
13195
+ if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
13196
+ return false;
13197
+ }
13198
+ // bias is expected to be 1D
13199
+ if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
13200
+ !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
13201
+ return false;
13202
+ }
13203
+ // sigmoid fusion seems to generate infinities on moltenvk
13204
+ if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
13205
+ return false;
13206
+ }
12928
13207
  break;
12929
13208
  case TOPK_MOE_EARLY_SOFTMAX:
12930
13209
  softmax = cgraph->nodes[node_idx + 0];
12931
13210
  weights = cgraph->nodes[node_idx + 4];
13211
+ get_rows = cgraph->nodes[node_idx + 4];
13212
+ argsort = cgraph->nodes[node_idx + 2];
12932
13213
  break;
12933
13214
  case TOPK_MOE_LATE_SOFTMAX:
12934
13215
  softmax = cgraph->nodes[node_idx + 4];
12935
13216
  weights = cgraph->nodes[node_idx + 5];
13217
+ get_rows = cgraph->nodes[node_idx + 2];
13218
+ argsort = cgraph->nodes[node_idx + 0];
12936
13219
  break;
12937
13220
  default:
12938
13221
  return false;
12939
13222
  }
12940
13223
 
12941
- const float * op_params = (const float *)softmax->op_params;
12942
-
12943
- float scale = op_params[0];
12944
- float max_bias = op_params[1];
12945
-
12946
- if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
13224
+ ggml_tensor * probs = get_rows->src[0];
13225
+ if (probs->op != GGML_OP_RESHAPE) {
12947
13226
  return false;
12948
13227
  }
13228
+ probs = probs->src[0];
13229
+ ggml_tensor * selection_probs = argsort->src[0];
12949
13230
 
12950
- if (scale != 1.0f || max_bias != 0.0f) {
13231
+ if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
12951
13232
  return false;
12952
13233
  }
12953
13234
 
12954
- // don't fuse when masks or sinks are present
12955
- if (softmax->src[1] || softmax->src[2]) {
13235
+ if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
12956
13236
  return false;
12957
13237
  }
12958
13238
 
13239
+ if (softmax->op == GGML_OP_SOFT_MAX) {
13240
+ const float * op_params = (const float *)softmax->op_params;
13241
+
13242
+ float scale = op_params[0];
13243
+ float max_bias = op_params[1];
13244
+
13245
+ if (scale != 1.0f || max_bias != 0.0f) {
13246
+ return false;
13247
+ }
13248
+
13249
+ // don't fuse when masks or sinks are present
13250
+ if (softmax->src[1] || softmax->src[2]) {
13251
+ return false;
13252
+ }
13253
+ }
13254
+
12959
13255
  const int n_expert = softmax->ne[0];
12960
13256
  if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
12961
13257
  return false;
@@ -12997,9 +13293,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
12997
13293
  return false;
12998
13294
  }
12999
13295
 
13000
- // Only norm/neox shaders have the fusion code
13296
+ // Only norm/neox/mrope shaders have the fusion code
13001
13297
  const int mode = ((const int32_t *) rope->op_params)[2];
13002
- if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
13298
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
13003
13299
  return false;
13004
13300
  }
13005
13301
 
@@ -13226,6 +13522,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13226
13522
  total_mul_mat_bytes += bytes;
13227
13523
  }
13228
13524
 
13525
+ ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
13526
+ ctx->fused_topk_moe_scale = false;
13229
13527
  const char *fusion_string {};
13230
13528
  if (!ctx->device->disable_fusion) {
13231
13529
  uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@@ -13271,13 +13569,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13271
13569
  ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
13272
13570
  // view of argsort writes to memory
13273
13571
  ctx->fused_ops_write_mask |= 1 << 3;
13572
+ ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
13274
13573
  fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
13574
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
13575
+ ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
13576
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
13577
+ ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
13578
+ // view of argsort writes to memory
13579
+ ctx->fused_ops_write_mask |= 1 << 4;
13580
+ ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
13581
+ fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
13275
13582
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
13276
13583
  ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
13277
13584
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
13278
13585
  ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
13279
13586
  // view of argsort writes to memory
13280
13587
  ctx->fused_ops_write_mask |= 1 << 3;
13588
+ ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
13281
13589
  fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
13282
13590
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
13283
13591
  ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@@ -13285,8 +13593,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13285
13593
  ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
13286
13594
  // view of argsort writes to memory
13287
13595
  ctx->fused_ops_write_mask |= 1 << 1;
13596
+ ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
13288
13597
  fusion_string = "TOPK_MOE_LATE_SOFTMAX";
13289
13598
  }
13599
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
13600
+ // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
13601
+ if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
13602
+ ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
13603
+ ctx->fused_topk_moe_scale = true;
13604
+ ctx->num_additional_fused_ops++;
13605
+ }
13606
+ }
13290
13607
  }
13291
13608
  ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
13292
13609
 
@@ -13465,6 +13782,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
13465
13782
  if (keep_pattern(topk_moe_early_softmax_norm)) {
13466
13783
  continue;
13467
13784
  }
13785
+ if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
13786
+ continue;
13787
+ }
13468
13788
  if (keep_pattern(topk_moe_early_softmax)) {
13469
13789
  continue;
13470
13790
  }
@@ -13491,6 +13811,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
13491
13811
  }
13492
13812
  // Don't pull forward nodes from fusion patterns
13493
13813
  if (match_pattern(topk_moe_early_softmax_norm, j) ||
13814
+ match_pattern(topk_moe_sigmoid_norm_bias, j) ||
13494
13815
  match_pattern(topk_moe_early_softmax, j) ||
13495
13816
  match_pattern(topk_moe_late_softmax, j)) {
13496
13817
  continue;
@@ -13502,7 +13823,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
13502
13823
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
13503
13824
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
13504
13825
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
13505
- !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
13826
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
13827
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
13506
13828
  ok = false;
13507
13829
  break;
13508
13830
  }
@@ -13630,11 +13952,62 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
13630
13952
  }
13631
13953
  }
13632
13954
 
13955
+ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
13956
+ VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
13957
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13958
+ vk_event *vkev = (vk_event *)event->context;
13959
+
13960
+ vk_context transfer_ctx;
13961
+
13962
+ if (ctx->transfer_ctx.expired()) {
13963
+ // Initialize new transfer context
13964
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13965
+ ctx->transfer_ctx = transfer_ctx;
13966
+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13967
+ } else {
13968
+ transfer_ctx = ctx->transfer_ctx.lock();
13969
+ }
13970
+
13971
+ // the backend interface doesn't have an explicit reset, so reset it here
13972
+ // before we record the command to set it
13973
+ ctx->device->device.resetEvent(vkev->event);
13974
+ ctx->device->device.resetFences({ vkev->fence });
13975
+
13976
+ ggml_vk_set_event(transfer_ctx, vkev->event);
13977
+
13978
+ ggml_vk_ctx_end(transfer_ctx);
13979
+
13980
+ ggml_vk_submit(transfer_ctx, {vkev->fence});
13981
+ ctx->submit_pending = true;
13982
+ ctx->transfer_ctx.reset();
13983
+ }
13984
+
13985
+ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
13986
+ VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
13987
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13988
+ vk_event *vkev = (vk_event *)event->context;
13989
+
13990
+ vk_context transfer_ctx;
13991
+
13992
+ if (ctx->transfer_ctx.expired()) {
13993
+ // Initialize new transfer context
13994
+ transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13995
+ ctx->transfer_ctx = transfer_ctx;
13996
+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13997
+ } else {
13998
+ transfer_ctx = ctx->transfer_ctx.lock();
13999
+ }
14000
+
14001
+ ggml_vk_wait_events(transfer_ctx, {vkev->event});
14002
+ ggml_vk_ctx_end(transfer_ctx);
14003
+ ctx->transfer_ctx.reset();
14004
+ }
14005
+
13633
14006
  // TODO: enable async and synchronize
13634
14007
  static ggml_backend_i ggml_backend_vk_interface = {
13635
14008
  /* .get_name = */ ggml_backend_vk_name,
13636
14009
  /* .free = */ ggml_backend_vk_free,
13637
- /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
14010
+ /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
13638
14011
  /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
13639
14012
  /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
13640
14013
  /* .synchronize = */ ggml_backend_vk_synchronize,
@@ -13643,8 +14016,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
13643
14016
  /* .graph_plan_update = */ NULL,
13644
14017
  /* .graph_plan_compute = */ NULL,
13645
14018
  /* .graph_compute = */ ggml_backend_vk_graph_compute,
13646
- /* .event_record = */ NULL,
13647
- /* .event_wait = */ NULL,
14019
+ /* .event_record = */ ggml_backend_vk_event_record,
14020
+ /* .event_wait = */ ggml_backend_vk_event_wait,
13648
14021
  /* .graph_optimize = */ ggml_vk_graph_optimize,
13649
14022
  };
13650
14023
 
@@ -13819,10 +14192,10 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
13819
14192
  props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
13820
14193
  ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
13821
14194
  props->caps = {
13822
- /* .async = */ false,
14195
+ /* .async = */ true,
13823
14196
  /* .host_buffer = */ true,
13824
14197
  /* .buffer_from_host_ptr = */ false,
13825
- /* .events = */ false,
14198
+ /* .events = */ true,
13826
14199
  };
13827
14200
  }
13828
14201
 
@@ -13842,6 +14215,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
13842
14215
  case GGML_UNARY_OP_GELU_QUICK:
13843
14216
  case GGML_UNARY_OP_SILU:
13844
14217
  case GGML_UNARY_OP_RELU:
14218
+ case GGML_UNARY_OP_XIELU:
13845
14219
  case GGML_UNARY_OP_NEG:
13846
14220
  case GGML_UNARY_OP_TANH:
13847
14221
  case GGML_UNARY_OP_SIGMOID:
@@ -14191,7 +14565,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14191
14565
  }
14192
14566
  return true;
14193
14567
  case GGML_OP_UPSCALE:
14194
- return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
14568
+ if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
14569
+ if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
14570
+ return false;
14571
+ }
14572
+ }
14573
+ return op->src[0]->type == GGML_TYPE_F32;
14195
14574
  case GGML_OP_ACC:
14196
14575
  return op->src[0]->type == GGML_TYPE_F32;
14197
14576
  case GGML_OP_CONCAT:
@@ -14353,6 +14732,47 @@ static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml
14353
14732
  UNUSED(dev);
14354
14733
  }
14355
14734
 
14735
+ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
14736
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14737
+ auto device = ggml_vk_get_device(ctx->device);
14738
+
14739
+ vk_event *vkev = new vk_event;
14740
+ if (!vkev) {
14741
+ return nullptr;
14742
+ }
14743
+
14744
+ // The event/fence is expected to initially be in the signaled state.
14745
+ vkev->event = device->device.createEvent({});
14746
+ vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
14747
+ device->device.setEvent(vkev->event);
14748
+
14749
+ return new ggml_backend_event {
14750
+ /* .device = */ dev,
14751
+ /* .context = */ vkev,
14752
+ };
14753
+ }
14754
+
14755
+ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
14756
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14757
+ auto device = ggml_vk_get_device(ctx->device);
14758
+
14759
+ vk_event *vkev = (vk_event *)event->context;
14760
+
14761
+ device->device.destroyFence(vkev->fence);
14762
+ device->device.destroyEvent(vkev->event);
14763
+ delete vkev;
14764
+ delete event;
14765
+ }
14766
+
14767
+ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
14768
+ VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
14769
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14770
+ auto device = ggml_vk_get_device(ctx->device);
14771
+ vk_event *vkev = (vk_event *)event->context;
14772
+
14773
+ VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
14774
+ }
14775
+
14356
14776
  static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
14357
14777
  /* .get_name = */ ggml_backend_vk_device_get_name,
14358
14778
  /* .get_description = */ ggml_backend_vk_device_get_description,
@@ -14366,9 +14786,9 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
14366
14786
  /* .supports_op = */ ggml_backend_vk_device_supports_op,
14367
14787
  /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
14368
14788
  /* .offload_op = */ ggml_backend_vk_device_offload_op,
14369
- /* .event_new = */ NULL,
14370
- /* .event_free = */ NULL,
14371
- /* .event_synchronize = */ NULL,
14789
+ /* .event_new = */ ggml_backend_vk_device_event_new,
14790
+ /* .event_free = */ ggml_backend_vk_device_event_free,
14791
+ /* .event_synchronize = */ ggml_backend_vk_device_event_synchronize,
14372
14792
  };
14373
14793
 
14374
14794
  static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
@@ -14747,7 +15167,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
14747
15167
  } else if (tensor->op == GGML_OP_LOG) {
14748
15168
  tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
14749
15169
  } else if (tensor->op == GGML_OP_TRI) {
14750
- tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
15170
+ tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
14751
15171
  } else if (tensor->op == GGML_OP_DIAG) {
14752
15172
  tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
14753
15173
  } else if (tensor->op == GGML_OP_CLAMP) {
@@ -14835,6 +15255,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
14835
15255
  case GGML_UNARY_OP_RELU:
14836
15256
  tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
14837
15257
  break;
15258
+ case GGML_UNARY_OP_XIELU:
15259
+ tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
15260
+ ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
15261
+ ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
15262
+ ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
15263
+ ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
15264
+ break;
14838
15265
  case GGML_UNARY_OP_NEG:
14839
15266
  tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
14840
15267
  break;