@novastera-oss/llamarn 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/README.md +86 -3
  2. package/RNLlamaCpp.podspec +1 -1
  3. package/android/CMakeLists.txt +11 -3
  4. package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
  5. package/android/src/main/cpp/include/llama.h +53 -114
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  13. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  20. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  21. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  22. package/cpp/LlamaCppModel.cpp +2 -10
  23. package/cpp/PureCppImpl.cpp +71 -4
  24. package/cpp/SystemUtils.cpp +3 -7
  25. package/cpp/build-info.cpp +2 -2
  26. package/cpp/llama.cpp/CMakeLists.txt +2 -0
  27. package/cpp/llama.cpp/CODEOWNERS +1 -1
  28. package/cpp/llama.cpp/Makefile +6 -1605
  29. package/cpp/llama.cpp/README.md +5 -1
  30. package/cpp/llama.cpp/common/arg.cpp +230 -51
  31. package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
  32. package/cpp/llama.cpp/common/chat.cpp +539 -8
  33. package/cpp/llama.cpp/common/chat.h +8 -1
  34. package/cpp/llama.cpp/common/common.cpp +60 -15
  35. package/cpp/llama.cpp/common/common.h +64 -15
  36. package/cpp/llama.cpp/common/speculative.cpp +135 -54
  37. package/cpp/llama.cpp/common/speculative.h +8 -1
  38. package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
  39. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
  40. package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
  41. package/cpp/llama.cpp/flake.nix +0 -5
  42. package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
  43. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
  44. package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
  45. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  46. package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
  47. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
  51. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
  52. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
  54. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
  55. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
  57. package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  100. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
  144. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  145. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
  147. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
  148. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
  149. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  150. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
  152. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
  167. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
  168. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
  169. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
  170. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
  173. package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
  174. package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
  176. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
  178. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
  179. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
  180. package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
  183. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  184. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  186. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  187. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
  188. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  189. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
  195. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
  196. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
  197. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
  198. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
  199. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
  201. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  202. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
  203. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
  204. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  207. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
  210. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  211. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  212. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
  216. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
  217. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
  218. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
  219. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
  225. package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
  226. package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
  227. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
  228. package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
  229. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
  230. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
  231. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
  232. package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
  233. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
  234. package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
  235. package/cpp/llama.cpp/include/llama.h +53 -114
  236. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
  237. package/cpp/llama.cpp/models/templates/README.md +2 -1
  238. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
  239. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
  240. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
  241. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
  242. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
  243. package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
  244. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  245. package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
  246. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  247. package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
  248. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  249. package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
  250. package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
  251. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  252. package/cpp/llama.cpp/src/llama-context.cpp +61 -252
  253. package/cpp/llama.cpp/src/llama-context.h +10 -15
  254. package/cpp/llama.cpp/src/llama-cparams.h +0 -1
  255. package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
  256. package/cpp/llama.cpp/src/llama-graph.h +90 -51
  257. package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
  258. package/cpp/llama.cpp/src/llama-hparams.h +21 -6
  259. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
  260. package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
  261. package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
  262. package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
  263. package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
  264. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
  265. package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
  266. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
  267. package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
  268. package/cpp/llama.cpp/src/llama-memory.h +13 -10
  269. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  270. package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
  271. package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
  272. package/cpp/llama.cpp/src/llama-model.h +28 -4
  273. package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
  274. package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
  275. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  276. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
  277. package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
  278. package/cpp/rn-completion.cpp +3 -27
  279. package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
  280. package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
  281. package/ios/include/chat.h +8 -1
  282. package/ios/include/common/minja/chat-template.hpp +16 -7
  283. package/ios/include/common/minja/minja.hpp +47 -12
  284. package/ios/include/common.h +64 -15
  285. package/ios/include/llama.h +53 -114
  286. package/ios/include/speculative.h +8 -1
  287. package/ios/libs/llama.xcframework/Info.plist +18 -18
  288. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  289. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
  290. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  291. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
  292. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
  293. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  294. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  295. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
  296. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  297. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  298. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  299. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  300. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  301. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  302. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  303. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
  304. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
  305. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
  306. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
  307. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
  308. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
  309. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
  310. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  311. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
  312. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
  313. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
  314. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  315. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  316. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  317. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
  318. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  319. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
  320. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
  321. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  322. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  323. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
  324. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
  325. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  326. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  327. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  328. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  329. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  330. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
  331. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
  332. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
  333. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
  334. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  335. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  336. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
  337. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
  338. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
  339. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
  340. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
  341. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  342. package/lib/module/NativeRNLlamaCpp.js.map +1 -1
  343. package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
  344. package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
  345. package/package.json +1 -2
  346. package/src/NativeRNLlamaCpp.ts +7 -0
  347. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -102,7 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 struct ggml_backend_vk_context;
 
-#define MAX_PARAMETER_COUNT 8
+#define MAX_PARAMETER_COUNT 12
+// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
 
 struct vk_pipeline_struct {
     std::string name;
@@ -113,6 +115,8 @@ struct vk_pipeline_struct {
     uint32_t parameter_count;
     std::array<uint32_t, 3> wg_denoms;
     uint32_t align;
+    // true if fields have been set by ggml_vk_create_pipeline
+    bool initialized {};
     // set to true to request the pipeline is compiled after the dryrun
     bool needed {};
     // set to true when the shader has been compiled
@@ -222,21 +226,7 @@ enum vk_device_architecture {
     AMD_RDNA2,
     AMD_RDNA3,
     INTEL_XE2,
-};
-
-// HSK x HSV
-enum FaHeadSizes {
-    FA_HEAD_SIZE_64,
-    FA_HEAD_SIZE_80,
-    FA_HEAD_SIZE_96,
-    FA_HEAD_SIZE_112,
-    FA_HEAD_SIZE_128,
-    FA_HEAD_SIZE_192,
-    FA_HEAD_SIZE_192_128,
-    FA_HEAD_SIZE_256,
-    FA_HEAD_SIZE_576_512,
-    FA_HEAD_SIZE_UNSUPPORTED,
-    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+    NVIDIA_PRE_TURING,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -315,10 +305,64 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
             // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
             return vk_device_architecture::INTEL_XE2;
         }
+    } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool cooperative_matrix = false;
+
+        // Detect "pre-turing" based on lack of coopmat support.
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
+                cooperative_matrix = true;
+                break;
+            }
+        }
+
+        if (!cooperative_matrix) {
+            return vk_device_architecture::NVIDIA_PRE_TURING;
+        }
     }
     return vk_device_architecture::OTHER;
 }
 
+enum vk_conv_shapes {
+    CONV_SHAPE_128x128,
+    CONV_SHAPE_64x32,
+    CONV_SHAPE_32x256,
+    CONV_SHAPE_COUNT,
+};
+
+enum dmmv_wg_sizes {
+    DMMV_WG_SIZE_SUBGROUP,
+    DMMV_WG_SIZE_LARGE,
+    DMMV_WG_SIZE_COUNT,
+};
+
+enum FaCodePath {
+    FA_SCALAR,
+    FA_COOPMAT1,
+    FA_COOPMAT2,
+};
+
+struct vk_fa_pipeline_state {
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
+
+    uint32_t HSK, HSV;
+    bool small_rows;
+    FaCodePath path;
+    bool aligned;
+    bool f32acc;
+
+    bool operator<(const vk_fa_pipeline_state &b) const {
+        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+    }
+};
+
+static constexpr uint32_t num_argsort_pipelines = 11;
+static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
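Note: `vk_fa_pipeline_state` replaces the fixed `[FA_HEAD_SIZE_COUNT][2][2][2]` flash-attention pipeline arrays from 0.3.1 — later in this diff, `vk_device_struct` keys a `std::map` on it. The `std::tie` comparator orders keys lexicographically over all six fields, so each (HSK, HSV, path, ...) combination gets exactly one cache slot and new head-size combinations no longer need enum entries. A minimal sketch of the lookup pattern this enables (the key and map types come from the diff; the helper itself is illustrative, not the package's exact code):

    // Hypothetical lookup helper; vk_fa_pipeline_state, FaCodePath and
    // vk_pipeline are the types declared in the diff above.
    std::map<vk_fa_pipeline_state, vk_pipeline> fa_cache;

    vk_pipeline &get_fa_pipeline(uint32_t HSK, uint32_t HSV, bool small_rows,
                                 FaCodePath path, bool aligned, bool f32acc) {
        vk_fa_pipeline_state key(HSK, HSV, small_rows, path, aligned, f32acc);
        // operator[] default-constructs an empty slot on first use; the caller
        // can then create the pipeline lazily if the slot is still empty.
        return fa_cache[key];
    }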
@@ -344,6 +388,11 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool subgroup_ballot;
+    bool multi_add;
+
+    bool add_rms_fusion;
+    uint32_t partials_binding_alignment;
 
     bool integer_dot_product;
 
@@ -405,8 +454,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_quantize_q8_1;
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
@@ -424,11 +473,20 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
+
+    // indexed by num_additional_fused_ops == num_adds - 1
+    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
+
+    vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sqrt_f32;
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
@@ -444,10 +502,13 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_mul_f32;
+    vk_pipeline pipeline_rms_norm_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_partials_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
     // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_exp[2];
     vk_pipeline pipeline_gelu[2];
     vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
@@ -459,6 +520,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_swiglu_oai[2];
     vk_pipeline pipeline_geglu_erf[2];
     vk_pipeline pipeline_geglu_quick[2];
 
@@ -472,7 +534,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
-    vk_pipeline pipeline_argsort_f32;
+    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
@@ -483,20 +545,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
-    vk_pipeline pipeline_conv2d_f32;
-    vk_pipeline pipeline_conv2d_dw_whcn_f32;
-    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
-
-    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
-    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+    vk_pipeline pipeline_opt_step_sgd_f32;
+    vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
+    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
 
-    vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
-
-    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
+    std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
-    std::unordered_map<std::string, vk_pipeline_ref> pipelines;
+    std::vector<vk_pipeline_ref> all_pipelines;
 
     std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
 
@@ -506,6 +565,7 @@ struct vk_device_struct {
     ggml_backend_buffer_type buffer_type;
 
     bool disable_fusion;
+    bool disable_host_visible_vidmem;
 
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
@@ -526,15 +586,15 @@ struct vk_device_struct {
         compute_queue.cmd_pool.destroy(device);
         transfer_queue.cmd_pool.destroy(device);
 
-        for (auto& pipeline : pipelines) {
-            if (pipeline.second.expired()) {
+        for (auto& pipeline : all_pipelines) {
+            if (pipeline.expired()) {
                 continue;
             }
 
-            vk_pipeline pl = pipeline.second.lock();
+            vk_pipeline pl = pipeline.lock();
             ggml_vk_destroy_pipeline(device, pl);
         }
-        pipelines.clear();
+        all_pipelines.clear();
 
         device.destroyDescriptorSetLayout(dsl);
 
@@ -680,6 +740,8 @@ struct vk_op_glu_push_constants {
     uint32_t ne00;
     uint32_t ne20;
     uint32_t mode; // 0: default, 1: swapped, 2: split
+    float alpha; // for swiglu_oai
+    float limit;
 };
 
 struct vk_op_unary_push_constants {
@@ -769,6 +831,28 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_multi_add_push_constants {
+    // shape for dst
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
+
+    // strides for srcs+dst
+    uint32_t nb[MAX_PARAMETER_COUNT][4];
+
+    uint32_t rms_partials;
+};
+// update multi_add.comp if this changes
+static_assert(MAX_PARAMETER_COUNT == 12);
+static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
+
+struct vk_op_add_id_push_constants {
+    uint32_t ne0;
+    uint32_t ne1;
+    uint32_t s01;
+    uint32_t s02;
+    uint32_t s11;
+    uint32_t s21;
+};
+
 struct vk_op_diag_mask_push_constants {
     uint32_t ncols;
     uint32_t rows_per_channel;
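Note: the 256-byte `static_assert` bound on `vk_op_multi_add_push_constants` checks out by hand: four `uint32_t` shape fields (16 bytes) + a 12x4 `uint32_t` stride table (192 bytes) + one flag (4 bytes) = 212 bytes. Vulkan only guarantees `maxPushConstantsSize` >= 128, so a block this large relies on the device advertising a bigger limit. A standalone restatement of the size math (mirroring, not copied from, the struct above):

    #include <cstdint>

    struct multi_add_pc {                 // mirrors vk_op_multi_add_push_constants
        uint32_t ne20, ne21, ne22, ne23;  // dst shape:        16 bytes
        uint32_t nb[12][4];               // src/dst strides: 192 bytes
        uint32_t rms_partials;            // flag:              4 bytes
    };

    static_assert(sizeof(multi_add_pc) == 4*4 + 12*4*4 + 4, "212 bytes, no padding");
    static_assert(sizeof(multi_add_pc) <= 256, "must fit the push constant budget");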
@@ -810,11 +894,11 @@ struct vk_op_soft_max_push_constants {
     float m1;
     uint32_t n_head_log2;
     uint32_t nrows_x;
+    uint32_t has_sinks;
 };
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
-    uint32_t ncols_pad;
     int32_t order;
 };
 
@@ -907,8 +991,22 @@ struct vk_op_conv2d_push_constants {
     uint32_t nb1;
     uint32_t nb2;
     uint32_t nb3;
+
+    // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
+    uint32_t KWmp; uint32_t KWL;
+    uint32_t KWKHmp; uint32_t KWKHL;
+    uint32_t OWmp; uint32_t OWL;
+    uint32_t OWOHmp; uint32_t OWOHL;
 };
 
+template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
+    // Compute magic values to divide by KW, KW*KH, OW, OW*OH
+    init_fastdiv_values(p.KW, p.KWmp, p.KWL);
+    init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
+    init_fastdiv_values(p.OW, p.OWmp, p.OWL);
+    init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
+}
+
 struct vk_op_conv2d_dw_push_constants {
     uint32_t ne;
     uint32_t batches;
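Note: the `*mp`/`*L` pairs implement division by a shader-invariant integer as a multiply-high plus shift, so the shader avoids a hardware integer divide per work item. A self-contained sketch of the scheme, following the `init_fastdiv_values` construction ggml uses in its CUDA backend (illustrative; treat the details as assumptions rather than a verbatim copy of this file's helper):

    #include <cassert>
    #include <cstdint>

    // Precompute (mp, L) for a fixed divisor d > 0 such that, for the index
    // ranges used here (n well below 2^31), n / d == (umulhi(n, mp) + n) >> L.
    static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) {
        L = 0;                                    // L = ceil(log2(d))
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;
        }
        mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    }

    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
        uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32);  // umulhi(n, mp)
        return (hi + n) >> L;
    }

    int main() {
        uint32_t mp, L;
        init_fastdiv_values(7, mp, L);
        for (uint32_t n = 0; n < 1000000; n++) {
            assert(fastdiv(n, mp, L) == n / 7);
        }
    }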
@@ -935,6 +1033,39 @@ struct vk_op_upscale_push_constants {
     float sf0; float sf1; float sf2; float sf3;
 };
 
+struct vk_op_sum_rows_push_constants
+{
+    uint32_t n_cols;
+    uint32_t ne01, ne02;
+    uint32_t nb01, nb02, nb03;
+    uint32_t nb11, nb12, nb13;
+    float weight;
+    uint32_t misalign_offsets;
+    uint32_t ne0_12mp, ne0_12L;
+    uint32_t ne0_1mp, ne0_1L;
+};
+
+static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
+    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
+    vk_op_sum_rows_push_constants p = {};
+    p.n_cols = (uint32_t)n_cols;
+    p.ne01 = (uint32_t)src->ne[1];
+    p.ne02 = (uint32_t)src->ne[2];
+    p.nb01 = (uint32_t)src->nb[1] / type_size;
+    p.nb02 = (uint32_t)src->nb[2] / type_size;
+    p.nb03 = (uint32_t)src->nb[3] / type_size;
+    p.nb11 = (uint32_t)dst->nb[1] / type_size;
+    p.nb12 = (uint32_t)dst->nb[2] / type_size;
+    p.nb13 = (uint32_t)dst->nb[3] / type_size;
+    p.weight = 1.0f;
+    return p;
+}
+
+template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
+    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
+    init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
+}
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1055,17 +1186,23 @@ class vk_perf_logger {
             return;
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
-            const uint64_t m = node->src[0]->ne[1];
-            const uint64_t n = node->src[1]->ne[1];
-            const uint64_t k = node->src[1]->ne[0];
-            std::string name = ggml_op_name(node->op);
-            if (n == 1) {
-                name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
-            } else {
-                name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
+            const uint64_t m = node->src[0]->ne[1];
+            const uint64_t n = node->ne[1];
+            const uint64_t k = node->src[1]->ne[0];
+            const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
+            std::string name = ggml_op_name(node->op);
+            if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
+                (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
+                name += "_VEC";
+            }
+            name += " ";
+            name += ggml_type_name(node->src[0]->type);
+            name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
+            if (batch > 1) {
+                name += " batch=" + std::to_string(batch);
             }
             timings[name].push_back(time);
-            flops[name].push_back(m * n * (k + (k - 1)));
+            flops[name].push_back(m * n * (k + (k - 1)) * batch);
             return;
         }
         if (node->op == GGML_OP_CONV_2D) {
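Note: the `flops` bookkeeping counts a length-k dot product as k multiplies plus k-1 adds, i.e. `k + (k - 1) = 2k - 1` operations per output element, now also scaled by the broadcast batch count taken from src1's outer dimensions. Restated as a standalone helper:

    // FLOPs for an m x n x k matmul repeated over `batch` broadcasts:
    // each of the m*n outputs costs k multiplies and k-1 adds.
    uint64_t matmul_flops(uint64_t m, uint64_t n, uint64_t k, uint64_t batch) {
        return m * n * (k + (k - 1)) * batch;
    }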
@@ -1089,6 +1226,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1103,10 +1246,25 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // write partial sums to accumulate the square of the vector components
+    bool do_add_rms_partials;
+
+    // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
+    vk_pipeline_struct * prealloc_y_last_pipeline_used {};
+    const ggml_tensor * prealloc_y_last_tensor_used {};
+
+    // Track which nodes have been used since the last sync, and whether they were written to
+    std::vector<const ggml_tensor *> unsynced_nodes_written;
+    std::vector<const ggml_tensor *> unsynced_nodes_read;
+    // Track which prealloc buffers have pending reads that need to be synchronized.
+    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
+    // and set to true after the buffer contents are consumed.
+    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -1340,13 +1498,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         vk::DebugUtilsObjectNameInfoEXT duoni;
         duoni.objectType = vk::ObjectType::ePipeline;
         duoni.pObjectName = pipeline->name.c_str();
-        duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
+        duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
         vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
     }
 
     {
         std::lock_guard<std::recursive_mutex> guard(device->mutex);
-        device->pipelines.insert({ pipeline->name, pipeline });
+        device->all_pipelines.push_back(pipeline);
     }
 
     {
@@ -1750,6 +1908,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
     } else if (device->uma) {
         // Fall back to host memory type
         buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+    } else if (device->disable_host_visible_vidmem) {
+        buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
     } else {
         // use rebar if available, otherwise fallback to device only visible memory
         buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -1781,14 +1941,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
     return { buf, 0, VK_WHOLE_SIZE };
 }
 
-static void ggml_vk_sync_buffers(vk_context& ctx) {
+static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
     VK_LOG_DEBUG("ggml_vk_sync_buffers()");
 
-    const bool transfer_queue = ctx->p->q->transfer_only;
+    const bool transfer_queue = subctx->p->q->transfer_only;
 
-    ctx->s->buffer.pipelineBarrier(
-        ctx->p->q->stage_flags,
-        ctx->p->q->stage_flags,
+    if (ctx) {
+        ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
+    }
+
+    subctx->s->buffer.pipelineBarrier(
+        subctx->p->q->stage_flags,
+        subctx->p->q->stage_flags,
         {},
         { {
           { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
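Note: per the comments on `prealloc_*_need_sync`, barriers around the preallocated staging buffers are deferred until a write would actually conflict with a pending read, and any recorded barrier clears all three flags, which is why `ggml_vk_sync_buffers` gained the backend-context parameter. The intended pattern, sketched (flag placement in the real code is spread across the matmul/conversion paths):

    // Before overwriting a prealloc buffer: sync only if a read is pending.
    if (ctx->prealloc_x_need_sync) {
        ggml_vk_sync_buffers(ctx, subctx);   // also resets all *_need_sync flags
    }
    // ... record the dispatch that writes prealloc_x ...

    // After recording a dispatch that reads prealloc_x:
    ctx->prealloc_x_need_sync = true;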
@@ -1815,47 +1979,12 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
-enum FaCodePath {
-    FA_SCALAR,
-    FA_COOPMAT1,
-    FA_COOPMAT2,
-};
-
-static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
-    if (hsk != 192 && hsk != 576 && hsk != hsv) {
-        return FA_HEAD_SIZE_UNSUPPORTED;
-    }
-    switch (hsk) {
-        case 64: return FA_HEAD_SIZE_64;
-        case 80: return FA_HEAD_SIZE_80;
-        case 96: return FA_HEAD_SIZE_96;
-        case 112: return FA_HEAD_SIZE_112;
-        case 128: return FA_HEAD_SIZE_128;
-        case 192:
-            if (hsv == 192) {
-                return FA_HEAD_SIZE_192;
-            } else if (hsv == 128) {
-                return FA_HEAD_SIZE_192_128;
-            } else {
-                return FA_HEAD_SIZE_UNSUPPORTED;
-            }
-        case 256: return FA_HEAD_SIZE_256;
-        case 576:
-            if (hsv == 512) {
-                return FA_HEAD_SIZE_576_512;
-            } else {
-                return FA_HEAD_SIZE_UNSUPPORTED;
-            }
-        default: return FA_HEAD_SIZE_UNSUPPORTED;
-    }
-}
-
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
 
 static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
-    if (hsv >= 512) {
+    if (hsv >= 192) {
         return 2;
     } else {
         return 8;
@@ -1885,7 +2014,13 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
         if (small_rows) {
             return {scalar_flash_attention_num_small_rows, 64};
         } else {
-            return {get_fa_scalar_num_large_rows(hsv), 32};
+            if ((hsv | hsk) & 8) {
+                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
+                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
+                return {get_fa_scalar_num_large_rows(hsv), 64};
+            } else {
+                return {get_fa_scalar_num_large_rows(hsv), 32};
+            }
         }
     }
 
@@ -1903,8 +2038,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     }
 
     // small cols to reduce register count
-    if (ggml_is_quantized(type) || hsk >= 256) {
-        if (hsk >= 512) {
+    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
+        if (hsk >= 512 || hsv >= 512) {
             return {32, 32};
         } else {
             return {64, 32};
@@ -1913,6 +2048,10 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     return {64, 64};
 }
 
+static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
+    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+}
+
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
 
     uint32_t lut_size = 0;
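Note: `fa_align` exposes the column tile of `fa_rows_cols` (its second element) as the unit for choosing between the aligned and unaligned flash-attention pipeline variants; presumably the caller tests the KV length against it, along these lines (illustrative sketch, not the exact call site):

    // A KV length that is a multiple of the tile's column count can use the
    // pipeline variant compiled with the aligned specialization and skip
    // bounds clamping on the K/V loads.
    const uint32_t align = fa_align(path, HSK, HSV, k_type, small_rows);
    const bool aligned = (KV % align) == 0;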
@@ -1938,6 +2077,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1938
2077
  break;
1939
2078
  case GGML_TYPE_IQ4_NL:
1940
2079
  case GGML_TYPE_IQ4_XS:
2080
+ case GGML_TYPE_MXFP4:
1941
2081
  lut_size = 4*16;
1942
2082
  break;
1943
2083
  default:
@@ -1950,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1950
2090
  const uint32_t warps = warptile[0] / warptile[10];
1951
2091
 
1952
2092
  const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1953
- const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
2093
+ const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
1954
2094
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
2095
+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
1955
2096
 
1956
- const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
2097
+ const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
1957
2098
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
1958
2099
 
1959
2100
  VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -2037,8 +2178,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
2037
2178
  const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
2038
2179
  const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
2039
2180
 
2181
+ const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
2182
+ const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
2183
+ const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
2184
+ const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
2185
+
2186
+ const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
2187
+ (device->subgroup_size_control && device->subgroup_max_size >= 16);
2188
+
2040
2189
  // mulmat
2041
2190
  std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
2191
+ l_warptile_id, m_warptile_id, s_warptile_id,
2042
2192
  l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
2043
2193
  l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
2044
2194
  l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
@@ -2067,17 +2217,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
2067
2217
  s_mmq_wg_denoms = { 32, 64, 1 };
2068
2218
 
2069
2219
  // spec constants and tile sizes for quant matmul (Qi_K)
2070
- l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
2071
- m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
2072
- s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
2073
- l_mmq_wg_denoms_k = { 64, 128, 1 };
2074
- m_mmq_wg_denoms_k = { 32, 64, 1 };
2075
- s_mmq_wg_denoms_k = { 32, 32, 1 };
2220
+ l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
2221
+ m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
2222
+ s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
2223
+ l_mmq_wg_denoms_k = { 128, 256, 1 };
2224
+ m_mmq_wg_denoms_k = { 128, 128, 1 };
2225
+ s_mmq_wg_denoms_k = { 32, 64, 1 };
2076
2226
 
2077
2227
  // spec constants and tile sizes for quant matmul_id
2078
- l_warptile_mmqid = { 256, 128, 128, 16, 0 };
2079
- m_warptile_mmqid = { 256, 128, 64, 16, 0 };
2080
- s_warptile_mmqid = { 256, 128, 64, 16, 0 };
2228
+ l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
2229
+ m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2230
+ s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2081
2231
  l_mmqid_wg_denoms = { 128, 128, 1 };
2082
2232
  m_mmqid_wg_denoms = { 128, 64, 1 };
2083
2233
  s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2109,9 +2259,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
2109
2259
  m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
2110
2260
  s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
2111
2261
 
2262
+ l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
2263
+ m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
2264
+ s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
2265
+
2266
+ l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
2267
+ m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
2268
+ s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
2269
+
2112
2270
  // chip specific tuning
2113
2271
  if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
2114
2272
  m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
2273
+ m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
2115
2274
  }
2116
2275
 
2117
2276
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2137,14 +2296,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
2137
2296
  }
2138
2297
 
2139
2298
  // Disable mul_mat_id if not enough shared memory is available
2140
- if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
2299
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
2141
2300
  device->mul_mat_id_s[i] = false;
2142
2301
  device->mul_mat_id_m[i] = false;
2143
2302
  device->mul_mat_id_l[i] = false;
2144
- } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
2303
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
2145
2304
  device->mul_mat_id_m[i] = false;
2146
2305
  device->mul_mat_id_l[i] = false;
2147
- } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
2306
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
2148
2307
  device->mul_mat_id_l[i] = false;
2149
2308
  }
2150
2309
  }
@@ -2177,11 +2336,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
2177
2336
 
2178
2337
  if (!pipeline) {
2179
2338
  pipeline = std::make_shared<vk_pipeline_struct>();
2339
+ }
2340
+ if (!pipeline->initialized) {
2180
2341
  pipeline->name = name;
2181
2342
  pipeline->parameter_count = parameter_count;
2182
2343
  pipeline->push_constant_size = push_constant_size;
2183
2344
  pipeline->wg_denoms = wg_denoms;
2184
2345
  pipeline->align = align;
2346
+ pipeline->initialized = true;
2185
2347
  }
2186
2348
 
2187
2349
  if (!pipeline->needed || pipeline->compiled) {
@@ -2227,26 +2389,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
  return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
  };

- #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-
  #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
+ for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
+ uint32_t HSK = fa.first.HSK; \
+ uint32_t HSV = fa.first.HSV; \
+ bool small_rows = fa.first.small_rows; \
+ FaCodePath path = fa.first.path; \
+ bool aligned = fa.first.aligned; \
+ bool f32acc = fa.first.f32acc; \
+ if (path == FAPATH) { \
+ if (aligned) { \
+ if (f32acc) { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } else { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } \
+ } else { \
+ if (f32acc) { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } else { \
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+ } \
+ } \
+ } \
+ }

  CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
  CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
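Note: this hunk replaces the fixed CREATE_FA2 expansion over nine (HSK, HSV) head-size pairs with a loop over `device->pipeline_flash_attn_f32_f16[TYPE]`, a container keyed by the requested configuration (HSK, HSV, small_rows, path, aligned, f32acc), so only variants that were actually requested get compiled. A minimal sketch of that demand-driven shape, assuming a map keyed by a config struct (the real key also carries the FaCodePath; it is omitted here for brevity):

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <tuple>

// Key type inferred from the fa.first fields used in the loop above.
struct FaKey {
    uint32_t HSK, HSV;
    bool small_rows, aligned, f32acc;
    bool operator<(const FaKey &o) const {
        return std::tie(HSK, HSV, small_rows, aligned, f32acc)
             < std::tie(o.HSK, o.HSV, o.small_rows, o.aligned, o.f32acc);
    }
};

struct Pipeline { /* handle, layout, ... */ };

// Entries are inserted when an op first requests a configuration; the
// shader-load pass then compiles exactly the entries present in the map,
// instead of unconditionally instantiating every head-size combination.
std::map<FaKey, std::shared_ptr<Pipeline>> fa_pipelines;

void load_requested(void (*create)(const FaKey &, std::shared_ptr<Pipeline> &)) {
    for (auto &fa : fa_pipelines) {
        create(fa.first, fa.second); // one compile per requested variant
    }
}
```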
@@ -2269,7 +2435,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
  }
  #endif
- #undef CREATE_FA2
  #undef CREATE_FA

  #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -2314,32 +2479,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+
+ GGML_ASSERT(device->subgroup_ballot);

- CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  if (device->coopmat_bf16_support) {
- CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
  }
  #endif
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
  #undef CREATE_MM
  #undef CREATE_MM2
  } else
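Note: on this coopmat path the mul_mat_id pipelines switch to `matmul_id_subgroup_*` shader variants, which rely on subgroup ballot operations; hence the new `GGML_ASSERT(device->subgroup_ballot)` before they are wired up. A minimal sketch of the capability query that typically backs a flag like `device->subgroup_ballot` (core Vulkan 1.1; this is illustrative, not the backend's exact probing code):

```cpp
#include <vulkan/vulkan.h>

// Query whether the physical device supports subgroup ballot operations.
bool query_subgroup_ballot(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceSubgroupProperties subgroup_props = {};
    subgroup_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &subgroup_props; // chain the subgroup query

    vkGetPhysicalDeviceProperties2(phys_dev, &props2);
    return (subgroup_props.supportedOperations & VK_SUBGROUP_FEATURE_BALLOT_BIT) != 0;
}
```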
@@ -2401,6 +2570,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  } else {
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2422,79 +2592,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  }

- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ GGML_ASSERT(device->subgroup_ballot);
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  if (device->coopmat_bf16_support) {
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  }
  #endif

- if (device->coopmat_acc_f16_support) {
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- } else {
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- }
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  #undef CREATE_MM2
  #undef CREATE_MM
  } else
  #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  if (device->fp16) {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
- #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
  if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \

  #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  if (device->mul_mat ## ID ## _l[TYPE]) { \
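Note: CREATE_MM gains a REQSUBGROUPSIZE argument, forwarded to ggml_vk_create_pipeline as extra "require full subgroups" / "required subgroup size" parameters so the subgroup-optimized variants are compiled with a pinned lane count. A minimal sketch of what such an argument typically turns into at Vulkan pipeline creation (VK_EXT_subgroup_size_control, core in Vulkan 1.3; this is an illustration of the mechanism, not the backend's exact plumbing):

```cpp
#include <vulkan/vulkan.h>

// Attach a required subgroup size to a compute stage, and optionally demand
// that workgroups contain no partial subgroups.
void fill_stage_info(VkPipelineShaderStageCreateInfo &stage,
                     VkPipelineShaderStageRequiredSubgroupSizeCreateInfo &req,
                     uint32_t required_subgroup_size, bool require_full) {
    if (required_subgroup_size > 0) {
        req.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO;
        req.pNext = nullptr;
        req.requiredSubgroupSize = required_subgroup_size; // pin the lane count
        stage.pNext = &req;
    }
    if (require_full) {
        // guarantee the shader never sees a partially filled subgroup
        stage.flags |= VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT;
    }
}
```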
@@ -2511,37 +2661,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
  } \

  // Create 2 variants, {f16,f32} accumulator
- #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+ CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
+
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  if (device->integer_dot_product) {
@@ -2553,50 +2704,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
  }
  #endif

- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
  #undef CREATE_MM2
  #undef CREATE_MMQ
  #undef CREATE_MM
  } else {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
- #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
  if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, REQSUBGROUPSIZE > 0, false, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
  if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \

  #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  if (device->mul_mat ## ID ## _l[TYPE]) \
@@ -2606,33 +2784,34 @@ static void ggml_vk_load_shaders(vk_device& device) {
2606
2784
  if (device->mul_mat ## ID ## _s[TYPE]) \
2607
2785
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
2608
2786
 
2609
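The REQSUBGROUPSIZE parameter added to CREATE_MM above is forwarded into ggml_vk_create_pipeline per variant: the unaligned l/m/s pipelines pass (REQSUBGROUPSIZE > 0) as the first new boolean, the aligned a_l/a_m/a_s pipelines pass it as the second, and both forward the raw value as the required subgroup size. A minimal self-contained sketch of that flag derivation, using illustrative names that are not part of the real ggml-vulkan API:

    #include <cstdint>
    #include <cstdio>

    // Sketch of how the new REQSUBGROUPSIZE argument is interpreted, based only
    // on the macro expansion above (struct/function names here are illustrative):
    //  - unaligned variants pass (REQSUBGROUPSIZE > 0) as the first extra flag
    //  - aligned variants pass it as the second extra flag instead
    //  - both forward REQSUBGROUPSIZE itself; 0 means "no constraint"
    struct subgroup_pipeline_flags {
        bool     flag_unaligned;         // first extra bool in the expansion
        bool     flag_aligned;           // second extra bool in the expansion
        uint32_t required_subgroup_size; // 0 keeps the driver default
    };

    static subgroup_pipeline_flags make_flags(uint32_t req_subgroup_size, bool aligned) {
        subgroup_pipeline_flags f{};
        f.flag_unaligned         = !aligned && req_subgroup_size > 0;
        f.flag_aligned           =  aligned && req_subgroup_size > 0;
        f.required_subgroup_size =  req_subgroup_size;
        return f;
    }

    int main() {
        // mul_mat_subgroup_size = 32 is only an example value
        subgroup_pipeline_flags f = make_flags(32, /*aligned=*/false);
        std::printf("%d %d %u\n", f.flag_unaligned, f.flag_aligned, f.required_subgroup_size);
    }

With REQSUBGROUPSIZE == 0, as in the fallback invocations below, both flags stay false and no subgroup size is pinned, matching the pre-change behavior.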
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
 
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  if (device->integer_dot_product) {
@@ -2644,32 +2823,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
  }
  #endif

- CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ } else {
+ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
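The new branch above only selects the matmul_id_subgroup_* shader variants when the device reports subgroup ballot support, the ability to require full subgroups, and (per the subgroup_min_size_16 flag computed earlier in this function) a subgroup size that can be pinned at 16; every other device keeps the previous generic shaders with REQSUBGROUPSIZE = 0. A compact, self-contained model of that gate, with simplified field names standing in for the real vk_device members:

    #include <cstdio>
    #include <string>

    // Illustrative model of the capability gate above; the real code reads
    // vk_device members (subgroup_ballot, subgroup_require_full_support, ...).
    struct device_caps {
        bool subgroup_ballot;
        bool subgroup_require_full_support;
        bool subgroup_min_size_16; // precomputed earlier in ggml_vk_load_shaders
    };

    static std::string pick_matmul_id_variant(const device_caps & caps) {
        if (caps.subgroup_ballot && caps.subgroup_require_full_support && caps.subgroup_min_size_16) {
            return "matmul_id_subgroup_*"; // pinned subgroup size, full subgroups required
        }
        return "matmul_id_*";              // generic fallback, REQSUBGROUPSIZE == 0
    }

    int main() {
        device_caps caps{true, true, true}; // example capability set
        std::printf("%s\n", pick_matmul_id_variant(caps).c_str());
    }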
  }
  // reusing CREATE_MM from the fp32 path
  if ((device->coopmat2 || device->coopmat_support)
@@ -2686,8 +2892,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  m_wg_denoms = { 64, 64, 1 };
  s_wg_denoms = { 32, 32, 1 };

- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
  }
  #undef CREATE_MM
 
@@ -2705,52 +2911,61 @@ static void ggml_vk_load_shaders(vk_device& device) {
  rm_stdq = 2;
  uint32_t rm_iq = 2 * rm_kq;

- for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
+ uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4);
+ uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4);
+
+ const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN;
+
+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[s], arr_dmmv_f16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[s], arr_dmmv_bf16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[s], arr_dmmv_q4_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[s], arr_dmmv_q4_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[s], arr_dmmv_q5_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[s], arr_dmmv_q5_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[s], arr_dmmv_q8_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[s], arr_dmmv_q2_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[s], arr_dmmv_q3_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[s], arr_dmmv_q4_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[s], arr_dmmv_q5_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[s], arr_dmmv_q6_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[s], arr_dmmv_iq1_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[s], arr_dmmv_iq1_m_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[s], arr_dmmv_iq2_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[s], arr_dmmv_iq2_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[s], arr_dmmv_iq2_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[s], arr_dmmv_iq3_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[s], arr_dmmv_iq3_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[s], arr_dmmv_iq4_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[s], arr_dmmv_iq4_nl_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[s], arr_dmmv_mxfp4_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[s], arr_dmmv_f32_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[s], arr_dmmv_f16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[s], arr_dmmv_bf16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[s], arr_dmmv_q4_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[s], arr_dmmv_q4_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[s], arr_dmmv_q5_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[s], arr_dmmv_q5_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[s], arr_dmmv_q8_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[s], arr_dmmv_q2_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[s], arr_dmmv_q3_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[s], arr_dmmv_q4_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[s], arr_dmmv_q5_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[s], arr_dmmv_q6_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[s], arr_dmmv_iq1_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[s], arr_dmmv_iq1_m_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[s], arr_dmmv_iq2_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[s], arr_dmmv_iq2_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[s], arr_dmmv_iq2_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[s], arr_dmmv_iq3_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[s], arr_dmmv_iq3_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[s], arr_dmmv_iq4_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true);
+ }
  }

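The rewritten loop above now builds every dequant mul_mat_vec pipeline in DMMV_WG_SIZE_COUNT workgroup-size buckets (one where the workgroup equals a single subgroup, one four times larger), while the boolean s selects the arr_dmmv_*_len/_data shader variants that use subgroup reductions when subgroup_add is available and the GPU is not AMD GCN. A runnable sketch of the size computation; only DMMV_WG_SIZE_SUBGROUP is named in the hunk, so the second enumerator's name below is an assumption:

    #include <cstdint>
    #include <cstdio>

    // Mirrors wg_size_subgroup / wg_size_subgroup16 from the hunk above.
    enum dmmv_wg_size : uint32_t {
        DMMV_WG_SIZE_SUBGROUP = 0, // workgroup == one subgroup
        DMMV_WG_SIZE_LARGE    = 1, // assumed name for the 4x bucket
        DMMV_WG_SIZE_COUNT    = 2,
    };

    int main() {
        const uint32_t subgroup_size    = 32; // example device value
        const uint32_t subgroup_size_16 = 16; // example, as used for K-quant shaders
        for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
            uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : subgroup_size_16 * 4;
            uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size   : subgroup_size * 4;
            std::printf("bucket %u: wg=%u wg16=%u\n", w, wg_size_subgroup, wg_size_subgroup16);
        }
    }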
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2775,6 +2990,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);

  // dequant shaders
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2797,6 +3013,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

  // get_rows
2802
3019
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2816,6 +3033,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2816
3033
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2817
3034
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2818
3035
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3036
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2819
3037
 
2820
3038
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2821
3039
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2834,9 +3052,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
2834
3052
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2835
3053
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2836
3054
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3055
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2837
3056
 
2838
3057
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2839
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
3058
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
2840
3059
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2841
3060
 
2842
3061
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2846,12 +3065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
2846
3065
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
2847
3066
  }
2848
3067
  }
2849
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
3068
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2850
3069
 
2851
3070
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2852
3071
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2853
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2854
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
3072
+
3073
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
3074
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3075
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
3076
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
3077
+
2855
3078
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2856
3079
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2857
3080
 
@@ -2921,22 +3144,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
2921
3144
  };
2922
3145
 
2923
3146
  bool rte = device->float_controls_rte_fp16;
2924
- #define CREATE_BINARY(name, namemod, spec) \
3147
+ #define CREATE_BINARY(name, namemod, spec, bindings) \
2925
3148
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2926
3149
  ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2927
3150
  #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
2928
- "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2929
-
2930
- CREATE_BINARY(add, , {0})
2931
- CREATE_BINARY(add, _norepeat, {1})
2932
- CREATE_BINARY(sub, , {0})
2933
- CREATE_BINARY(sub, _norepeat, {1})
2934
- CREATE_BINARY(mul, , {0})
2935
- CREATE_BINARY(mul, _norepeat, {1})
2936
- CREATE_BINARY(div, , {0})
2937
- CREATE_BINARY(div, _norepeat, {1})
3151
+ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
3152
+
3153
+ CREATE_BINARY(add, , {0}, 4)
3154
+ CREATE_BINARY(add, _norepeat, {1}, 4)
3155
+ CREATE_BINARY(sub, , {0}, 3)
3156
+ CREATE_BINARY(sub, _norepeat, {1}, 3)
3157
+ CREATE_BINARY(mul, , {0}, 3)
3158
+ CREATE_BINARY(mul, _norepeat, {1}, 3)
3159
+ CREATE_BINARY(div, , {0}, 3)
3160
+ CREATE_BINARY(div, _norepeat, {1}, 3)
3161
+ CREATE_BINARY(add_rms, , {0}, 4)
3162
+ CREATE_BINARY(add_rms, _norepeat, {1}, 4)
2938
3163
  #undef CREATE_BINARY
2939
3164
 
3165
+ if (device->multi_add) {
3166
+ for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
3167
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3168
+ ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
3169
+ }
3170
+ }
3171
+
3172
+ ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
3173
+
2940
3174
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2941
3175
 
2942
3176
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2950,6 +3184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2950
3184
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2951
3185
 
2952
3186
  ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3187
+ ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2953
3188
  ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2954
3189
  ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2955
3190
 
@@ -2966,6 +3201,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2966
3201
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
2967
3202
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2968
3203
 
3204
+ CREATE_UNARY(exp)
2969
3205
  CREATE_UNARY(gelu)
2970
3206
  CREATE_UNARY(gelu_erf)
2971
3207
  CREATE_UNARY(gelu_quick)
@@ -2987,6 +3223,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2987
3223
  CREATE_GLU(geglu)
2988
3224
  CREATE_GLU(reglu)
2989
3225
  CREATE_GLU(swiglu)
3226
+ CREATE_GLU(swiglu_oai)
2990
3227
  CREATE_GLU(geglu_erf)
2991
3228
  CREATE_GLU(geglu_quick)
2992
3229
  #undef CREATE_GLU
@@ -2996,10 +3233,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
2996
3233
 
2997
3234
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
2998
3235
 
2999
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3000
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3001
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3002
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3236
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3237
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3238
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3239
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
3003
3240
  ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3004
3241
 
3005
3242
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3019,11 +3256,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
3019
3256
  ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
3020
3257
  }
3021
3258
 
3022
- ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
3259
+ for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
3260
+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
3261
+ }
3023
3262
 
3024
3263
  ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3025
3264
 
3026
- ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3265
+ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
3027
3266
 
3028
3267
  ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
3029
3268
 
@@ -3046,44 +3285,114 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+ ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
  // conv2d
- uint32_t conv2d_WG_SIZE = 256;
- uint32_t conv2d_BS_K = 128;
- uint32_t conv2d_BS_CRS = 16;
- uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
- if (device->subgroup_shuffle &&
- device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
- use_collectives = 1;
- conv2d_BS_CRS = std::min(
- device->subgroup_size,
- conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used.
- }
- uint32_t conv2d_BS_NPQ = 128;
- uint32_t conv2d_TS_K = 8;
- uint32_t conv2d_shmem_req =
- (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
- if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
- conv2d_BS_CRS = 8;
- if (use_collectives) {
- conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
- }
- }
-
- if (use_collectives) {
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
- } else {
- ggml_vk_create_pipeline(
- device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
- false);
+ for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
+ uint32_t conv2d_WG_SIZE = 256;
+ uint32_t conv2d_BS_K = 128;
+ uint32_t conv2d_BS_CRS = 16;
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
+ uint32_t conv2d_BS_NPQ = 128;
+ uint32_t conv2d_TS_K = 8;
+ uint32_t conv2d_SHMEM_PAD = 4;
+ bool conv2d_UNROLL = true;
+
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ if (device->coopmat2) {
+ conv2d_SHMEM_PAD = 8; // 8 float16_t
+ }
+ #endif
+
+ if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+ conv2d_SHMEM_PAD = 0;
+ conv2d_UNROLL = false;
+ } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
+ conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
+ }
+
+ switch (s) {
+ default:
+ case CONV_SHAPE_128x128:
+ conv2d_BS_K = 128;
+ conv2d_BS_NPQ = 128;
+ conv2d_BS_CRS = 16;
+ if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
+ conv2d_UNROLL = false;
+ }
+ break;
+ case CONV_SHAPE_64x32:
+ conv2d_BS_K = 64;
+ conv2d_BS_NPQ = 32;
+ conv2d_BS_CRS = 32;
+ conv2d_TS_K = 4;
+ break;
+ case CONV_SHAPE_32x256:
+ conv2d_BS_K = 32;
+ conv2d_BS_NPQ = 256;
+ conv2d_BS_CRS = 16;
+ break;
+ }
+
+ // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
+ bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
+ device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
+ bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
+ device->architecture == vk_device_architecture::AMD_GCN;
+
+ if (device->subgroup_shuffle &&
+ device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
+ allow_collectives_nv &&
+ allow_collectives_amd) {
+ use_collectives = 1;
+ conv2d_BS_CRS = std::min(
+ device->subgroup_size,
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
+ }
+
+ uint32_t conv2d_shmem_req =
+ (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
+ conv2d_BS_CRS = 8;
+ if (use_collectives) {
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
+ }
+ }
+
+ std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
+ std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
+
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+ if (device->coopmat2) {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ } else
+ #endif
+ if (conv2d_UNROLL) {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ } else {
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ ggml_vk_create_pipeline(
+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+ }
  }
 
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
  for (auto &c : compiles) {
  c.wait();
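Note on the shared-memory sizing in the conv2d hunk above: the estimate is two padded tiles of floats, `BS_K x (BS_CRS + PAD)` for weights and `BS_CRS x (BS_NPQ + PAD)` for activations. A minimal standalone sketch of the same arithmetic, using the CONV_SHAPE_128x128 defaults from this hunk (the helper name is hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical standalone version of the conv2d shared-memory estimate above:
// two padded tiles (BS_K x BS_CRS and BS_CRS x BS_NPQ) of 4-byte floats.
static uint32_t conv2d_shmem_bytes(uint32_t bs_k, uint32_t bs_crs, uint32_t bs_npq, uint32_t pad) {
    return (bs_k * (bs_crs + pad) + bs_crs * (bs_npq + pad)) * (uint32_t)sizeof(float);
}

int main() {
    // CONV_SHAPE_128x128 defaults: BS_K=128, BS_CRS=16, BS_NPQ=128, SHMEM_PAD=4
    // -> (128*20 + 16*132) * 4 = (2560 + 2112) * 4 = 18688 bytes, comfortably
    // under a common 32 KiB maxComputeSharedMemorySize, so BS_CRS stays at 16.
    printf("%u bytes\n", (unsigned) conv2d_shmem_bytes(128, 16, 128, 4));
    return 0;
}
```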
@@ -3125,6 +3434,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
  device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
 
+ const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
+ device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
+
  bool fp16_storage = false;
  bool fp16_compute = false;
  bool maintenance4_support = false;
@@ -3269,6 +3581,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
 
+ device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
+
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -3402,6 +3717,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 
+ device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
+ device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
+ vk12_features.runtimeDescriptorArray &&
+ device->vendor_id != VK_VENDOR_ID_INTEL &&
+ getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
+
  if (device->subgroup_size_control) {
  device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
  device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -3412,9 +3733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
  (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
  subgroup_size_control_features.subgroupSizeControl;
 
- if (device->subgroup_size_control) {
- device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
- }
+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
 
  #if defined(VK_KHR_cooperative_matrix)
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
@@ -3715,6 +4034,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
  device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
+ device->add_rms_fusion = !device->disable_fusion &&
+ device->subgroup_add &&
+ device->vendor_id != VK_VENDOR_ID_INTEL;
+ device->partials_binding_alignment =
+ std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
  return device;
  }
 
@@ -4139,6 +4464,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
  case GGML_TYPE_IQ3_S:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
  break;
  default:
  return nullptr;
@@ -4209,6 +4535,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  case GGML_TYPE_IQ3_S:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
  break;
  default:
  return nullptr;
@@ -4224,7 +4551,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
  }
 
- static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
+ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
  VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
  GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
  GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
@@ -4252,12 +4579,30 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  case GGML_TYPE_IQ3_S:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
  break;
  default:
  return nullptr;
  }
 
- return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
+ // heuristic to choose workgroup size
+ uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
+ if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+ // Prefer larger workgroups when M is small, to spread the work out more
+ // and keep more SMs busy.
+ // q6_k seems to prefer small workgroup size even for "medium" values of M.
+ if (a_type == GGML_TYPE_Q6_K) {
+ if (m < 4096 && k >= 1024) {
+ dmmv_wg = DMMV_WG_SIZE_LARGE;
+ }
+ } else {
+ if (m <= 8192 && k >= 1024) {
+ dmmv_wg = DMMV_WG_SIZE_LARGE;
+ }
+ }
+ }
+
+ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
  }
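The new `m`/`k` parameters exist only to feed this workgroup-size heuristic. A hedged sketch of the decision in isolation (the `DMMV_WG_SIZE_*` selectors are defined elsewhere in this file; the helper name is hypothetical):

```cpp
#include <cstdint>

enum dmmv_wg_size { DMMV_WG_SIZE_SUBGROUP, DMMV_WG_SIZE_LARGE };

// Mirrors the heuristic above: on NVIDIA/Intel, small-M / large-K matvecs get
// a larger workgroup to keep more SMs busy; Q6_K only switches for smaller M.
static dmmv_wg_size pick_dmmv_wg(bool nvidia_or_intel, bool is_q6_k, uint32_t m, uint32_t k) {
    if (nvidia_or_intel && k >= 1024) {
        if (is_q6_k ? m < 4096 : m <= 8192) {
            return DMMV_WG_SIZE_LARGE;
        }
    }
    return DMMV_WG_SIZE_SUBGROUP;
}
```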
 
  static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
@@ -4306,12 +4651,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
  case GGML_TYPE_IQ3_S:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
  break;
  default:
  return nullptr;
  }
 
- return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ // XXX TODO 'prec' is not actually allowed in mul_mat_id.
+ bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
+ bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
+ bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
+
+ if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
+ return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
+ } else {
+ GGML_ASSERT(support_fp32acc);
+ return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
+ }
  }
 
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -4341,6 +4697,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
  case GGML_TYPE_IQ3_S:
  case GGML_TYPE_IQ4_XS:
  case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
  break;
  default:
  return nullptr;
@@ -4526,6 +4883,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
  std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
  GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
  GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+ GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
 
  vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
  vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -4648,7 +5006,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
  }
  }
 
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
  subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
  return;
  }
@@ -4663,7 +5021,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_cont
  ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
  VkBufferCopy buf_copy{ 0, offset, copy_size };
 
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
  vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
 
  for (uint64_t i3 = 0; i3 < ne3; i3++) {
@@ -4717,7 +5075,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
  }
  }
 
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
  subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
  return;
  }
@@ -4738,7 +5096,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
  offset,
  copy_size};
 
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
  vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
 
  if (width == spitch) {
@@ -4818,7 +5176,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
 
  if (buf != nullptr) {
  // Memory is pinned, use as staging buffer
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
  subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
 
  return;
@@ -4835,7 +5193,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
 
  vk_buffer& staging_buffer = src->device->sync_staging;
 
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(nullptr, subctx);
  subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
 
  deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
@@ -4933,26 +5291,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
  ggml_vk_queue_command_pools_cleanup(dst->device);
  }
 
- static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
+ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
  VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
 
  uint32_t split_k = 1;
- if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+ if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
  // If k is 'large' and the SMs will fill less than halfway, use split_k.
  uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
  uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
- if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
- split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
- // Clamp to 2 or 4
- split_k = std::min(split_k, 4u);
- if (split_k == 3) {
- split_k = 2;
+
+ if (k >= 2048) {
+ if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
+ split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+ } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
+ split_k = 3;
  }
- if (ctx->device->coopmat2) {
- // coopmat2 shader expects splits to be aligned to 256
- while (split_k > 1 && ((k / split_k) % 256) != 0) {
- split_k /= 2;
+ // Cap the split at 8x. Unless k is huge this is a lot of overhead.
+ split_k = std::min(split_k, 8u);
+
+ // ggml_vk_matmul will align the splits to be a multiple of 256.
+ // If this rounded up size would cause the last split to be empty,
+ // then reduce the split count.
+ while (true) {
+ if (split_k == 1) {
+ break;
  }
+ uint32_t k_split = CEIL_DIV(k, split_k);
+ k_split = ROUNDUP_POW2(k_split, 256);
+ if (k_split * (split_k - 1) < k) {
+ break;
+ }
+ split_k--;
  }
  }
  }
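The trailing-split check above is easiest to see with a worked example. A sketch, assuming `CEIL_DIV` and `ROUNDUP_POW2` round up as their names suggest:

```cpp
#include <cstdint>

// Standalone sketch of the split reduction above: after rounding each split up
// to a multiple of 256 (the k-quant superblock alignment), drop splits that
// would end up with no work.
static uint32_t reduce_split_k(uint32_t k, uint32_t split_k) {
    while (split_k > 1) {
        uint32_t k_split = (k + split_k - 1) / split_k; // CEIL_DIV(k, split_k)
        k_split = (k_split + 255) & ~255u;              // ROUNDUP_POW2(k_split, 256)
        if (k_split * (split_k - 1) < k) {
            break; // the last split still has K elements to process
        }
        split_k--;
    }
    return split_k;
}

// Example: reduce_split_k(2304, 8) == 5. CEIL_DIV(2304, 8) = 288 rounds up to
// 512, and 512 * 7 >= 2304 would leave trailing splits empty; the count drops
// until 512 * 4 = 2048 < 2304.
```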
@@ -4964,9 +5333,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
 
  if (ctx->device->coopmat2) {
+ const uint32_t shader_core_count = ctx->device->shader_core_count;
+ const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
+ const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
+
  // Use large shader when the N dimension is greater than the medium shader's tile size
  uint32_t crossover_large = mmp->m->wg_denoms[1];
- if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+
+ // Prefer large over medium if either:
+ // - medium or large tiles would overfill the GPU
+ // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
+ // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
+ bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
+ // split_k==3 with large tiles likely better than medium tiles with no split_k.
+ (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
+
+ if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
  return aligned ? mmp->a_l : mmp->l;
  }
  // Use medium shader when the N dimension is greater than the small shader's tile size
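A worked example of the `prefer_large` rule, with assumed tile sizes (128x128 large, 64x64 medium) and an assumed 60-core GPU; only the comparison itself is taken from the hunk above:

```cpp
#include <cstdint>

// The exact comparison from the hunk above, factored out for illustration.
static bool prefer_large_tiles(uint32_t tiles_l, uint32_t tiles_m, uint32_t cores) {
    return tiles_m > cores || tiles_l > cores ||
           (tiles_l <= cores / 3 && tiles_m > cores / 2);
}

// With m = n = 512: tiles_l = 4*4 = 16, tiles_m = 8*8 = 64, cores = 60.
// tiles_m (64) > cores (60), so the large tile is preferred; even if it were
// not, tiles_l (16) <= 20 and tiles_m (64) > 30 would also trigger the rule.
```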
@@ -5001,21 +5383,29 @@ static void ggml_vk_matmul(
  uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
  uint32_t padded_n) {
  VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
- ggml_vk_sync_buffers(subctx);
  if (split_k == 1) {
  const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
  return;
  }
 
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
  GGML_ASSERT(batch_stride_d == m * n);
 
- const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
+ // Round the split size up to a multiple of 256 (k-quant alignment)
+ uint32_t k_split = CEIL_DIV(k, split_k);
+ k_split = ROUNDUP_POW2(k_split, 256);
+
+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
  // Make sure enough workgroups get assigned for split k to work
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
+ ctx->prealloc_split_k_need_sync = true;
  }
 
  static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
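The `prealloc_*_need_sync` flags introduced throughout this diff implement a deferred-barrier pattern: instead of a barrier before every dispatch, a dirty flag is set after writing a preallocated buffer and the barrier is issued only when that buffer is next reused. A minimal sketch of the idea (names are illustrative, not the actual context members):

```cpp
// Illustrative sketch of the deferred-barrier pattern used above.
struct prealloc_sync_flag {
    bool need_sync = false;
};

template <typename SyncFn>
static void reuse_prealloc(prealloc_sync_flag & flag, SyncFn && sync) {
    if (flag.need_sync) {
        sync();                 // e.g. ggml_vk_sync_buffers(ctx, subctx)
        flag.need_sync = false;
    }
    // ... record the write into the preallocated buffer here ...
    flag.need_sync = true;      // the next reuse must wait on this write
}
```

The payoff is that back-to-back matmuls that never touch the same scratch buffer no longer pay for pipeline barriers between them.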
@@ -5060,7 +5450,6 @@ static void ggml_vk_matmul_id(
  "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
  "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
  "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
- ggml_vk_sync_buffers(subctx);
  const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
  nei0, nei1, nbi1, ne11, padded_n };
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
@@ -5191,8 +5580,8 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  init_pushconst_fastdiv(pc);
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
+ ggml_vk_sync_buffers(ctx, subctx);
  }
 
  static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
@@ -5210,14 +5599,14 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
 
  vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
+ ggml_vk_sync_buffers(ctx, subctx);
  }
 
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
- VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
- std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
- std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+ VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
  GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
@@ -5406,18 +5795,39 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  GGML_ASSERT(qy_sz == y_sz);
  }
 
+ if (x_non_contig || qx_needs_dequant) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+ if (y_non_contig || quantize_y) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
  if (x_non_contig) {
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
  } else if (qx_needs_dequant) {
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+ ggml_vk_sync_buffers(ctx, subctx);
  }
  if (y_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
  }
  if (quantize_y) {
- ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
  }
 
  uint32_t stride_batch_x = ne00*ne01;
@@ -5440,6 +5850,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
  split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
  ); // NOLINT
+
+ if (x_non_contig || qx_needs_dequant) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig || quantize_y) {
+ ctx->prealloc_y_need_sync = true;
+ }
  }
 
  static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
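The `prealloc_y_last_pipeline_used` / `prealloc_y_last_tensor_used` pair added above acts as a one-entry cache: the conversion of src1 into the preallocated Y buffer is skipped when the same tensor was already converted by the same pipeline, which is common when one activation tensor feeds several matmuls. A sketch of the check (types forward-declared for brevity; member names are illustrative):

```cpp
struct ggml_tensor;

// One-entry cache over the preallocated Y buffer, as in the hunks above:
// returns true when the conversion must actually run, and records the state.
struct prealloc_y_cache {
    const void * last_pipeline = nullptr;
    const ggml_tensor * last_tensor = nullptr;

    bool needs_convert(const void * pipeline, const ggml_tensor * tensor) {
        if (last_pipeline == pipeline && last_tensor == tensor) {
            return false; // buffer contents still valid from a previous dispatch
        }
        last_pipeline = pipeline;
        last_tensor = tensor;
        return true;
    }
};
```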
@@ -5523,7 +5940,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
  } else {
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
  }
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
  GGML_ASSERT(dmmv != nullptr);
@@ -5586,13 +6003,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context&
  GGML_ASSERT(qy_sz == y_sz);
  }
 
+ if (x_non_contig) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+ if (y_non_contig) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
  if (x_non_contig) {
  GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
  }
  if (y_non_contig) {
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
  }
 
  // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
@@ -5624,10 +6057,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context&
  stride_batch_x, stride_batch_y, stride_batch_d,
  (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
  };
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
  { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} },
  pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+
+ if (x_non_contig) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig) {
+ ctx->prealloc_y_need_sync = true;
+ }
  }
 
  static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5714,7 +6153,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_c
  workgroups_z /= gqa_ratio;
  }
 
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
  }
 
@@ -5732,7 +6170,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_con
  const uint64_t ne00 = src0->ne[0];
  const uint64_t ne01 = src0->ne[1];
  const uint64_t ne02 = src0->ne[2];
- // const uint64_t ne03 = src0->ne[3];
+ const uint64_t ne03 = src0->ne[3];
 
  const uint64_t nb01 = src0->nb[1];
  const uint64_t nb02 = src0->nb[2];
@@ -5744,7 +6182,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_con
  const uint64_t ne12 = src1->ne[2];
  // const uint64_t ne13 = src1->ne[3];
 
+ const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
+ const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
+ const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
+
  GGML_ASSERT(ne11 == 1);
+ GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
 
  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
  ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
@@ -5760,7 +6203,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_con
  src1_uma = d_Qy != nullptr;
  }
 
- const uint64_t d_ne = ne01 * ne11 * ne12;
+ const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;
 
  const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
  const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
@@ -5795,10 +6238,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_con
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
 
  // compute
- const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
- ggml_vk_sync_buffers(subctx);
+ const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
- { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+ { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
  }
 
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
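The `mul_mat_vec_nc` push-constant block grows from 9 to 12 words in the hunks above; the three new trailing fields are the fourth-dimension strides, matching the `{..., nb03, nb13, nb23}` initializer. An assumed field-by-field layout (names are illustrative; the shader consumes a raw uint array):

```cpp
#include <cstdint>

// Assumed layout of the 12-word push-constant array built above.
struct mul_mat_vec_nc_push_constants {
    uint32_t ne00, ne01;                     // row length and row count of src0
    uint32_t row_stride_x, channel_stride_x; // src0 strides (in elements)
    uint32_t channel_stride_y;               // src1 channel stride (in elements)
    uint32_t channel_x_divisor;              // ne12 / ne02 (GQA-style broadcast)
    uint32_t ne12;                           // number of channels
    uint32_t qy_offset, d_offset;            // shader-side element offsets
    uint32_t nb03, nb13, nb23;               // new: 4th-dimension strides
};
static_assert(sizeof(mul_mat_vec_nc_push_constants) == 12 * sizeof(uint32_t),
              "must match the std::array<uint32_t, 12> above");
```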
@@ -5847,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5847
6289
 
5848
6290
  const uint64_t nei0 = ids->ne[0];
5849
6291
  const uint64_t nei1 = ids->ne[1];
5850
- GGML_ASSERT(nei0 * nei1 <= 4096);
5851
6292
 
5852
6293
  const uint32_t nbi1 = ids->nb[1];
5853
6294
  const uint32_t nbi2 = ids->nb[2];
@@ -6008,16 +6449,32 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
6008
6449
  GGML_ASSERT(qy_sz == y_sz);
6009
6450
  }
6010
6451
 
6452
+ if (x_non_contig || qx_needs_dequant) {
6453
+ if (ctx->prealloc_x_need_sync) {
6454
+ ggml_vk_sync_buffers(ctx, subctx);
6455
+ }
6456
+ }
6457
+ if (y_non_contig) {
6458
+ if (ctx->prealloc_y_need_sync) {
6459
+ ggml_vk_sync_buffers(ctx, subctx);
6460
+ }
6461
+ }
6462
+
6011
6463
  if (x_non_contig) {
6012
6464
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
6013
6465
  } else if (qx_needs_dequant) {
6014
6466
  const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
6015
- ggml_vk_sync_buffers(subctx);
6016
6467
  ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
6017
6468
  { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
6469
+ ggml_vk_sync_buffers(ctx, subctx);
6018
6470
  }
6019
6471
  if (y_non_contig) {
6020
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
6472
+ if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
6473
+ ctx->prealloc_y_last_tensor_used != src1) {
6474
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
6475
+ ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
6476
+ ctx->prealloc_y_last_tensor_used = src1;
6477
+ }
6021
6478
  }
6022
6479
 
6023
6480
  uint32_t stride_batch_x = ne00*ne01;
@@ -6040,6 +6497,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
6040
6497
  stride_batch_x, stride_batch_y, ne20*ne21,
6041
6498
  n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
6042
6499
  ); // NOLINT
6500
+
6501
+ if (x_non_contig || qx_needs_dequant) {
6502
+ ctx->prealloc_x_need_sync = true;
6503
+ }
6504
+ if (y_non_contig) {
6505
+ ctx->prealloc_y_need_sync = true;
6506
+ }
6043
6507
  }
6044
6508
 
6045
6509
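The prealloc_x/y_need_sync flags amount to a deferred hazard barrier: a barrier is recorded only when a later node is about to overwrite a preallocated staging buffer that a previous dispatch may still be using, and the flag is raised again after each write. A rough sketch of the protocol, with hypothetical names:

    struct Staging {
        bool need_sync = false; // raised after a dispatch writes the buffer
    };

    template <typename BarrierFn, typename WriteFn>
    void write_with_deferred_sync(Staging &buf, BarrierFn barrier, WriteFn write) {
        if (buf.need_sync) {
            barrier();            // only pay for a barrier when the buffer is dirty
            buf.need_sync = false;
        }
        write();                  // record the dispatch that writes the buffer
        buf.need_sync = true;     // later reuses must synchronize first
    }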
  static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) {
@@ -6199,13 +6663,29 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_conte
  GGML_ASSERT(qy_sz == y_sz);
  }

+ if (x_non_contig) {
+ if (ctx->prealloc_x_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+ if (y_non_contig) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ }
+
  if (x_non_contig) {
  GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
  }
  if (y_non_contig) {
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
  }

  uint32_t stride_batch_y = ne10*ne11;
@@ -6230,11 +6710,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_conte
  (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
  (uint32_t)nei0, (uint32_t)ne11,
  };
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
  { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 },
  vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } },
  pc, { groups_x, (uint32_t)nei0, groups_z });
+
+ if (x_non_contig) {
+ ctx->prealloc_x_need_sync = true;
+ }
+ if (y_non_contig) {
+ ctx->prealloc_y_need_sync = true;
+ }
  }

  static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
@@ -6242,30 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
  if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
  ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
  } else {
- // Split based on number of ids, to fit in shared memory
- const uint32_t nei0 = (uint32_t)src2->ne[0];
- const uint32_t nei1 = (uint32_t)src2->ne[1];
-
- GGML_ASSERT(nei0 <= 4096);
- const uint32_t split_size = std::min(nei1, 4096u / nei0);
-
- ggml_tensor src1_copy = *src1;
- ggml_tensor src2_copy = *src2;
- ggml_tensor dst_copy = *dst;
-
- for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
- const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
-
- src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
- src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
- dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
-
- src1_copy.ne[2] = n_tokens;
- src2_copy.ne[1] = n_tokens;
- dst_copy.ne[2] = n_tokens;
-
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
- }
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
  }
  }

@@ -6298,18 +6761,21 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
  const uint32_t Br = coopmat1_flash_attention_num_large_rows;
  const uint32_t Bc = scalar_flash_attention_Bc;

+ const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+
  const uint32_t acctype = f32acc ? 4 : 2;
  const uint32_t f16vec4 = 8;

  const uint32_t tmpsh = wg_size * sizeof(float);
  const uint32_t tmpshv4 = wg_size * 4 * acctype;

- const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
+ const uint32_t qstride = hsk_pad / 4 + 2;
+ const uint32_t Qf = Br * qstride * f16vec4;

  const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
  const uint32_t sfsh = Bc * sfshstride * acctype;

- const uint32_t kshstride = hsk / 4 + 2;
+ const uint32_t kshstride = hsk_pad / 4 + 2;
  const uint32_t ksh = Bc * kshstride * f16vec4;

  const uint32_t slope = Br * sizeof(float);
@@ -6322,11 +6788,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
  return supported;
  }
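Padding hsk to a multiple of 16 keeps the cooperative-matrix tiles uniformly sized; the shared-memory estimate then follows from the padded row stride. A standalone sketch of the same arithmetic (the free function is hypothetical; the +2 padding entries mirror the code above):

    #include <cstdint>

    static uint32_t roundup_pow2(uint32_t x, uint32_t m) { // m must be a power of two
        return (x + m - 1) & ~(m - 1);
    }

    // Rough shared-memory bytes for the Q and K tiles: each row holds
    // hsk_pad/4 + 2 f16vec4 entries of 8 bytes apiece.
    static uint32_t qk_tile_bytes(uint32_t hsk, uint32_t Br, uint32_t Bc) {
        const uint32_t hsk_pad = roundup_pow2(hsk, 16);
        const uint32_t stride  = hsk_pad / 4 + 2;
        return (Br + Bc) * stride * 8; // Qf + ksh from the hunk above
    }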
 
- static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
+ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) {
  VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
  std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
  std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+ if (sinks) {
+ std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
+ }
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");

  GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
@@ -6417,7 +6886,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  workgroups_y /= N;
  }

- vk_pipeline *pipelines;
  bool small_rows = N <= get_fa_num_small_rows(path);

  // coopmat1 does not actually support "small rows" (it needs 16 rows).
@@ -6437,37 +6905,36 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  small_rows = true;
  }

- bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
- FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
-
- switch (path) {
- case FA_SCALAR:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
- break;
- case FA_COOPMAT1:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
- break;
- case FA_COOPMAT2:
- pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
- break;
- default:
- GGML_ASSERT(0);
- }
- assert(pipelines);
-
  const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
  const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
  const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));

- bool aligned = (KV % pipelines[1]->align) == 0 &&
+ uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+ bool aligned = (KV % alignment) == 0 &&
  // the "aligned" shader variant will forcibly align strides, for performance
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;

+ // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
+ if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+ aligned = false;
+ }
  // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
  GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);

- vk_pipeline pipeline = pipelines[aligned];
+ bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+ vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+
+ vk_pipeline pipeline = nullptr;
+
+ auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
+ auto it = pipelines.find(fa_pipeline_state);
+ if (it != pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
+ }
+
  assert(pipeline);

  uint32_t split_kv = KV;
@@ -6483,7 +6950,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  if (split_k > 1) {
  // Try to evenly split KV into split_k chunks, but it needs to be a multiple
  // of "align", so recompute split_k based on that.
- split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
  split_k = CEIL_DIV(KV, split_kv);
  workgroups_x = split_k;
  }
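CEIL_DIV and ROUNDUP_POW2 interact here: KV is divided into split_k chunks, each chunk is rounded up to the shader's alignment, and split_k is recomputed from the rounded chunk size. A worked sketch (a standalone restatement, not the package's macros):

    #include <algorithm>
    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t m) { return (x + m - 1) & ~(m - 1); }

    // e.g. KV = 1000, split_k = 3, alignment = 64:
    //   1000/3 = 333 -> rounded up to 384 -> split_k = ceil(1000/384) = 3,
    //   and the final chunk covers the remainder (1000 - 2*384 = 232).
    static void split_kv_range(uint32_t KV, uint32_t alignment,
                               uint32_t &split_k, uint32_t &split_kv) {
        split_kv = roundup_pow2(std::max(1u, KV / split_k), alignment);
        split_k  = ceil_div(KV, split_kv);
    }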
@@ -6525,10 +6992,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
- size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
+ vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
+ size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;

- bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;

  if (ctx->device->uma) {
  ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
@@ -6543,6 +7010,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
  M_uma = d_M != nullptr;
  }
+ if (sinks) {
+ ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
+ S_uma = d_S != nullptr;
+ }
  }


@@ -6578,7 +7049,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  }
  }

- uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+ if (!S_uma) {
+ d_S = d_Q;
+ s_buf_offset = q_buf_offset;
+ if (sinks) {
+ ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
+ d_S = s_buf_ctx->dev_buffer;
+ s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
+ }
+ }
+
+ uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;

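Both presence flags and n_head_log2 share a single 32-bit push constant; the shader undoes the packing with shifts and masks. A sketch of both directions (host-side C++ standing in for the GLSL decode):

    #include <cstdint>

    static uint32_t pack_flags(bool has_sinks, bool has_mask, uint32_t n_head_log2) {
        // n_head_log2 occupies the low bits; the flags sit at bits 24 and 16.
        return (uint32_t(has_sinks) << 24) | (uint32_t(has_mask) << 16) | n_head_log2;
    }

    static void unpack_flags(uint32_t v, bool &has_sinks, bool &has_mask, uint32_t &n_head_log2) {
        has_sinks   = (v >> 24) & 1;
        has_mask    = (v >> 16) & 1;
        n_head_log2 = v & 0xFFFF;
    }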
  const vk_flash_attn_push_constants pc = { N, KV,
  (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -6593,15 +7074,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  mask_n_head_log2, m0, m1,
  gqa_ratio, split_kv, split_k };

- ggml_vk_sync_buffers(subctx);
-
  if (split_k > 1) {
+ if (ctx->prealloc_split_k_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
  {
  vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
  },
  // We only use split_k when group query attention is enabled, which means
@@ -6610,14 +7094,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  // cancel out the divide by wg_denoms[0].
  pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

- ggml_vk_sync_buffers(subctx);
- const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
+ ggml_vk_sync_buffers(ctx, subctx);
+ const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
  {
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+ vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
  },
  pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+ ctx->prealloc_split_k_need_sync = true;
  } else {
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
  {
@@ -6625,13 +7111,42 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+ vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
  },
  pc, { workgroups_x, workgroups_y, workgroups_z });
  }
  }

- static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+ static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
+ const ggml_tensor *src0 = dst->src[0];
+ const ggml_tensor *src1 = dst->src[1];
+
+ // src0 - kernel: [KW, KH, Cin, Cout]
+ // src1 - input: [W, H, Cin, N]
+ // dst - result: [OW, OH, Cout, N]
+
+ // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+ };
+ // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
+ int64_t W = src1->ne[0];
+ int64_t H = src1->ne[1];
+ int64_t KW = src0->ne[0];
+ int64_t KH = src0->ne[1];
+ int64_t Cout = src0->ne[3];
+ int64_t N = src1->ne[3];
+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
+ int64_t NPQ = N * OW * OH;
+
+ // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
+ std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
+ return elements;
+ }
+
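calc_conv_output_size is the usual convolution shape formula, out = (in + 2p - d*(k - 1) - 1)/s + 1. A worked instance for a 224-wide input with a 3x3 kernel, stride 2, padding 1, dilation 1 (values chosen purely for illustration):

    #include <cassert>
    #include <cstdint>

    static int64_t conv_out(int64_t in, int64_t k, int s, int p, int d) {
        return (in + 2 * p - d * (k - 1) - 1) / s + 1;
    }

    int main() {
        // (224 + 2 - 2 - 1)/2 + 1 = 111 + 1 = 112
        assert(conv_out(224, 3, 2, 1, 1) == 112);
        // The dispatch grid above then tiles { Cout, N*OW*OH, 1 }.
        return 0;
    }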
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
  switch (op) {
  case GGML_OP_GET_ROWS:
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -6659,8 +7174,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  switch (op) {
  case GGML_OP_ADD:
  {
- auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
- return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+ if (ctx->num_additional_fused_ops > 0) {
+ if (ctx->do_add_rms_partials) {
+ return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+ } else {
+ return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+ }
+ }
+ if (ctx->do_add_rms_partials) {
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+ } else {
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+ }
  }
  case GGML_OP_SUB:
  {
@@ -6681,6 +7208,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  break;
  }
  return nullptr;
+ case GGML_OP_ADD_ID:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_add_id_f32;
+ }
+ return nullptr;
  case GGML_OP_CONCAT:
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_concat_f32;
@@ -6715,6 +7247,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_sqr_f32;
  }
  return nullptr;
+ case GGML_OP_SQRT:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sqrt_f32;
+ }
+ return nullptr;
  case GGML_OP_SIN:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_sin_f32;
@@ -6773,7 +7310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return nullptr;
  case GGML_OP_RMS_NORM:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ if (ctx->do_add_rms_partials) {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
+ } else {
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
+ }
  }
  return nullptr;
  case GGML_OP_RMS_NORM_BACK:
@@ -6794,6 +7335,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  }

  switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_EXP:
+ return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
  case GGML_UNARY_OP_SILU:
  return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
  case GGML_UNARY_OP_GELU:
@@ -6826,6 +7369,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
  case GGML_GLU_OP_SWIGLU:
  return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+ case GGML_GLU_OP_SWIGLU_OAI:
+ return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
  case GGML_GLU_OP_GEGLU_ERF:
  return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
  case GGML_GLU_OP_GEGLU_QUICK:
@@ -6841,6 +7386,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return nullptr;
  case GGML_OP_SOFT_MAX:
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);

  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -6895,11 +7441,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  }
  case GGML_OP_ARGSORT:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
- return ctx->device->pipeline_argsort_f32;
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+ return ctx->device->pipeline_argsort_f32[idx];
  }
  return nullptr;
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_sum_rows_f32;
  }
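The argsort pipeline is now chosen per column count: index ceil(log2(ne0)) selects a bitonic-sort variant sized for the next power of two, which is what removes the old fixed 1024-column cap. A sketch of the index math:

    #include <cmath>
    #include <cstdint>

    // ceil(log2(n)) is the smallest k with (1u << k) >= n, i.e. the padded
    // column count a bitonic sorting network operates on.
    static uint32_t argsort_pipeline_idx(uint32_t ncols) {
        return (uint32_t)ceilf(log2f((float)ncols)); // e.g. 1000 -> 10 (pad to 1024)
    }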
@@ -6952,15 +7500,44 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_opt_step_adamw_f32;
  }
  return nullptr;
+ case GGML_OP_OPT_STEP_SGD:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_opt_step_sgd_f32;
+ }
+ return nullptr;
  case GGML_OP_LEAKY_RELU:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_leaky_relu_f32;
  }
  return nullptr;
  case GGML_OP_CONV_2D:
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
  ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
- return ctx->device->pipeline_conv2d_f32;
+ auto elements = ggml_vk_get_conv_elements(dst);
+ vk_conv_shapes shape;
+
+ uint32_t tiles[CONV_SHAPE_COUNT];
+ for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
+ tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
+ }
+
+ // We can't query number of shader cores on Intel, use 32 as a placeholder
+ // so small convolutions will still choose a smaller tile.
+ const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
+
+ if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) {
+ shape = CONV_SHAPE_128x128;
+ } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) {
+ shape = CONV_SHAPE_32x256;
+ } else {
+ shape = CONV_SHAPE_64x32;
+ }
+
+ if (src0->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_conv2d_f32[shape];
+ } else if (src0->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_conv2d_f16_f32[shape];
+ }
  }
  return nullptr;
  case GGML_OP_CONV_2D_DW:
@@ -6970,6 +7547,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  } else if (ggml_is_contiguous_channels(src1)) {
  return ctx->device->pipeline_conv2d_dw_cwhn_f32;
  }
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ if (ggml_is_contiguous(src1)) {
+ return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
+ } else if (ggml_is_contiguous_channels(src1)) {
+ return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
+ }
  }
  return nullptr;
  default:
@@ -6987,9 +7570,11 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
  case GGML_OP_CONCAT:
  case GGML_OP_UPSCALE:
  case GGML_OP_SQR:
+ case GGML_OP_SQRT:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
@@ -7001,6 +7586,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
  case GGML_OP_CONV_2D_DW:
  case GGML_OP_IM2COL:
  case GGML_OP_SET_ROWS:
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  return true;
  default:
  return false;
@@ -7035,6 +7623,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
  GGML_UNUSED(src2);
  }

+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
+
+ p.misalign_offsets = (a_offset << 16) | d_offset;
+
+ GGML_UNUSED(src1);
+ GGML_UNUSED(src2);
+ }
+
  template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
  const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
  const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
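As with the flash-attention flags, the two misalignment offsets share one 32-bit field, 16 bits each; both are sub-alignment corrections, so they comfortably fit. A sketch of the packing and its inverse:

    #include <cstdint>

    static uint32_t pack_offsets(uint32_t a_offset, uint32_t d_offset) {
        return (a_offset << 16) | d_offset; // both well under 2^16 elements
    }
    static uint32_t unpack_a(uint32_t v) { return v >> 16; }
    static uint32_t unpack_d(uint32_t v) { return v & 0xFFFF; }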
@@ -7185,10 +7783,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);

  if (op_supports_incontiguous) {
- x_sz = ggml_nbytes(src0);
- y_sz = use_src1 ? ggml_nbytes(src1) : 0;
- z_sz = use_src2 ? ggml_nbytes(src2) : 0;
- d_sz = ggml_nbytes(dst);
+ x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
+ y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
+ z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+ d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);

  if (x_buf_offset + x_sz >= d_X->size) {
  x_sz = VK_WHOLE_SIZE;
@@ -7216,6 +7814,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  case GGML_OP_SOFT_MAX:
  case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
  case GGML_OP_ARGMAX:
  {
  const uint32_t nr = ggml_nrows(src0);
@@ -7228,7 +7827,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  }
  } break;
  case GGML_OP_RMS_NORM:
- elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+ if (ctx->do_add_rms_partials) {
+ // Run one element per thread, 128 threads per workgroup
+ elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+ } else {
+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+ }
  break;

  case GGML_OP_SUM:
@@ -7287,35 +7891,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  } break;
  case GGML_OP_CONV_2D:
  {
- // src0 - kernel: [KW, KH, Cin, Cout]
- // src1 - input: [W, H, Cin, N]
- // dst - result: [OW, OH, Cout, N]
-
- // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
- auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
- return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
- };
- // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
- int64_t W = src1->ne[0];
- int64_t H = src1->ne[1];
- int64_t KW = src0->ne[0];
- int64_t KH = src0->ne[1];
- int64_t Cout = src0->ne[3];
- int64_t N = src1->ne[3];
- int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
- int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
- int64_t NPQ = N * OW * OH;
-
- // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
- elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
- }
- break;
+ elements = ggml_vk_get_conv_elements(dst);
+ } break;
  case GGML_OP_ADD:
  case GGML_OP_SUB:
  case GGML_OP_DIV:
  case GGML_OP_MUL:
  case GGML_OP_SCALE:
  case GGML_OP_SQR:
+ case GGML_OP_SQRT:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
@@ -7354,6 +7938,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  elements = { ne, 1, 1 };
  }
  } break;
+ case GGML_OP_ADD_ID:
+ {
+ elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+ } break;
  case GGML_OP_SET_ROWS:
  {
  uint32_t ne = ggml_nelements(src0);
@@ -7393,8 +7981,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  }
  }

- if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
- // Empty src1 is possible in soft_max, but the shader needs a buffer
+ if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+ vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
+ size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+ { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+ vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+ vk_subbuffer{ d_D, d_buf_offset, d_sz },
+ vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+ }, pc, elements);
+ } else if (op == GGML_OP_GLU) {
+ // Empty src1 is possible in glu, but the shader needs a buffer
  vk_subbuffer subbuf_y;
  if (use_src1) {
  subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7402,8 +7999,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  subbuf_y = { d_X, 0, x_sz };
  }

- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ } else if (op == GGML_OP_SOFT_MAX) {
+ // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+ vk_subbuffer subbuf_y;
+ if (use_src1) {
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
+ } else {
+ subbuf_y = { d_X, 0, x_sz };
+ }
+
+ vk_subbuffer subbuf_z;
+ if (use_src2) {
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
+ } else {
+ subbuf_z = { d_X, 0, x_sz };
+ }
+
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
  // Empty src2 is possible in rope, but the shader needs a buffer
  vk_subbuffer subbuf_z;
@@ -7413,26 +8026,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  subbuf_z = { d_X, 0, x_sz };
  }

- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  } else if (op == GGML_OP_IM2COL) {
  // im2col uses only src1 and dst buffers
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  } else if (op == GGML_OP_COUNT_EQUAL) {
- ggml_vk_sync_buffers(subctx);
  // count_equal assumes that destination buffer is initialized with zeroes
  ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
- ggml_vk_sync_buffers(subctx);
+ ggml_vk_sync_buffers(ctx, subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ } else if (op == GGML_OP_OPT_STEP_SGD) {
+ // OPT_STEP_SGD works on src0, it does not need dst
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
  } else if (use_src2) {
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  } else if (use_src1) {
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  } else {
- ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
  }
  }
@@ -7472,6 +8082,116 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+ const ggml_tensor *first_node = cgraph->nodes[node_idx];
+ const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+
+ // Make a list of all the tensors used by the op.
+ // Last element of the list is the dest tensor.
+ const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
+ uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
+ uint32_t num_tensors = num_srcs + 1;
+ GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
+
+ tensors[0] = first_node->src[0];
+ tensors[1] = first_node->src[1];
+ for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
+ // check whether the previous result is src[0] or src[1]
+ if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
+ tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
+ } else {
+ tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
+ }
+ }
+ tensors[num_srcs] = dst;
+
+ vk_op_multi_add_push_constants pc;
+ pc.ne20 = (uint32_t)dst->ne[0];
+ pc.ne21 = (uint32_t)dst->ne[1];
+ pc.ne22 = (uint32_t)dst->ne[2];
+ pc.ne23 = (uint32_t)dst->ne[3];
+
+ for (uint32_t i = 0; i < num_tensors; ++i) {
+ const ggml_tensor *t = tensors[i];
+ pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
+ pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
+ pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
+ pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
+ }
+ pc.rms_partials = ctx->do_add_rms_partials;
+
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
+
+ if (pipeline == nullptr) {
+ std::cerr << "ggml_vulkan: Error: Missing multi_add";
+ GGML_ABORT("fatal error");
+ }
+
+ if (dryrun) {
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+ return;
+ }
+
+ ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
+ vk_buffer buf[MAX_PARAMETER_COUNT];
+ size_t offset[MAX_PARAMETER_COUNT];
+ bool uma[MAX_PARAMETER_COUNT];
+
+ for (uint32_t i = 0; i < num_tensors; ++i) {
+ buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+ buf[i] = nullptr;
+ offset[i] = 0;
+ uma[i] = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+ uma[i] = buf[i] != nullptr;
+ }
+ if (!uma[i]) {
+ buf[i] = buf_ctx[i]->dev_buffer;
+ offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+ }
+ GGML_ASSERT(buf[i] != nullptr);
+ }
+ // If any remaining descriptors are unused, just point them at src[0]
+ for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
+ buf[i] = buf[0];
+ offset[i] = 0;
+ }
+ if (ctx->do_add_rms_partials) {
+ buf[num_tensors] = ctx->prealloc_add_rms_partials;
+ offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
+ }
+
+ std::array<uint32_t, 3> elements;
+
+ uint32_t ne = ggml_nelements(dst);
+ if (ne > 262144) {
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
+ } else if (ne > 512) {
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
+ } else {
+ elements = { ne, 1, 1 };
+ }
+
+ static_assert(MAX_PARAMETER_COUNT == 12);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+ {
+ vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
+ vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
+ }, pc, elements);
+ }
+
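Vulkan caps each dispatch dimension (commonly 65535 workgroups), so the flat element count is folded into up to three axes, 512 x 512 = 262144 elements per z-slice; the shader linearizes the coordinates back and bounds-checks against ne. A sketch of the fold (a restatement of the logic above, not the package's code):

    #include <array>
    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    static std::array<uint32_t, 3> fold(uint32_t ne) {
        if (ne > 262144) return { 512, 512, ceil_div(ne, 262144) };
        if (ne > 512)    return { 512, ceil_div(ne, 512), 1 };
        return { ne, 1, 1 };
    }
    // Shader side (conceptually): idx = x + 512 * y + 262144 * z; if (idx >= ne) return;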
  static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
  const uint32_t src0_type_size = ggml_type_size(src0->type);
  const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -7483,7 +8203,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
  (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
  0,
- 0.0f, 0.0f, 0,
+ 0.0f, 0.0f, ctx->do_add_rms_partials,
  }, dryrun);
  }

@@ -7532,6 +8252,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+ (uint32_t)dst->ne[0],
+ (uint32_t)dst->ne[1],
+ (uint32_t)src0->nb[1] / src0_type_size,
+ (uint32_t)src0->nb[2] / src0_type_size,
+ (uint32_t)src1->nb[1] / src1_type_size,
+ (uint32_t)src2->nb[1] / src2_type_size,
+ }, dryrun);
+ }
+
  static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
  GGML_ASSERT(version == 6 || version == 7);
  int num_srcs = version == 6 ? 6 : 7;
@@ -7556,8 +8291,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
  src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
  }

- ggml_vk_sync_buffers(subctx);
-
  vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
  size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
  bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
@@ -7695,8 +8428,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
  ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
  ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;

- ggml_vk_sync_buffers(subctx);
-
  vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
  size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
  bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
@@ -7763,6 +8494,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
  );
  }

+ static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ const size_t n = ggml_nelements(dst->src[0]);
+
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
  int * op_params = (int *)dst->op_params;

@@ -7815,6 +8552,10 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
  ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
  }

+ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
+ }
+
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
  }
@@ -7882,6 +8623,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
  const uint32_t src1_type_size = ggml_type_size(src1->type);
  const uint32_t dst_type_size = ggml_type_size(dst->type);

+ // Skip empty skip_rows operations. For most ops the empty check at the start
+ // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
+ // with empty srcs.
+ if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
+ return;
+ }
+
  ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
  (uint32_t)ggml_nelements(src0),
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
@@ -7913,19 +8661,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
  }

+ static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t ne = (uint32_t)node->ne[0];
+ const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
+ const uint32_t num_partials = CEIL_DIV(ne, denom);
+ return num_partials;
+ }
+
+ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
+ const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
+ const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
+ return num_bytes;
+ }
+
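Each add workgroup (wg_denoms[0] elements wide) emits one partial sum, so the partials buffer needs CEIL_DIV(ne0, denom) words, rounded up to the device's binding alignment. A worked restatement:

    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t x, uint32_t m) { return (x + m - 1) & ~(m - 1); }

    // e.g. ne0 = 4096 with 128-wide workgroups -> 32 partials -> 128 bytes,
    // rounded up to a 256-byte binding alignment -> 256 bytes.
    static uint32_t partials_bytes(uint32_t ne0, uint32_t denom, uint32_t align) {
        return roundup_pow2(ceil_div(ne0, denom) * (uint32_t)sizeof(uint32_t), align);
    }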
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
  const uint32_t src0_type_size = ggml_type_size(src0->type);
  const uint32_t src1_type_size = ggml_type_size(src1->type);
  const uint32_t dst_type_size = ggml_type_size(dst->type);

+ uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
+
  ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
  (uint32_t)ggml_nelements(src0),
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
  (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
  0,
- op_params[0], 0.0f, 0,
+ op_params[0], 0.0f, (int32_t)param3,
  }, dryrun);
+
+ if (ctx->do_add_rms_partials) {
+ ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
+ ctx->do_add_rms_partials = false;
+ }
  }

  static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -7943,8 +8711,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
  }

  static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ const float * op_params_f = (const float *)dst->op_params;
+
  const bool swapped = (bool)dst->op_params[1];
  const bool split = src1 != nullptr;
+ const float alpha = op_params_f[2];
+ const float limit = op_params_f[3];

  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -7958,7 +8730,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const

  const uint32_t mode = split ? 2 : (swapped ? 1 : 0);

- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+ {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0],
+ (uint32_t)dst->ne[0],
+ mode,
+ alpha,
+ limit
+ }, dryrun);
  }

8744
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -7966,7 +8746,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
7966
8746
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
7967
8747
  }
7968
8748
 
7969
- static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8749
+ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
7970
8750
  float * op_params = (float *)dst->op_params;
7971
8751
 
7972
8752
  float scale = op_params[0];
@@ -7988,7 +8768,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7988
8768
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
7989
8769
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
7990
8770
 
7991
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
8771
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
7992
8772
  ncols,
7993
8773
  src1 != nullptr ? nrows_y : (uint32_t)0,
7994
8774
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -7998,6 +8778,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7998
8778
  m0, m1,
7999
8779
  n_head_log2,
8000
8780
  nrows_x,
8781
+ src2 != nullptr
8001
8782
  }, dryrun);
8002
8783
  }
8003
8784
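m0 and m1 are the two ALiBi slope bases. In the standard ALiBi recipe (a hedged reconstruction; the exact per-head expansion lives in the shader), head h takes slope m0^(h+1) below n_head_log2 and m1^(2(h-n_head_log2)+1) above it:

    #include <cmath>
    #include <cstdint>

    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        return h < n_head_log2 ? powf(m0, (float)(h + 1))
                               : powf(m1, (float)(2 * (h - n_head_log2) + 1));
    }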
 
@@ -8034,7 +8815,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
8034
8815
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
8035
8816
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
8036
8817
  src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
8037
- sections[0], sections[1], sections[2], sections[3], backprop
8818
+ { sections[0], sections[1], sections[2], sections[3] }, backprop
8038
8819
  }, dryrun);
8039
8820
  }
8040
8821
 
@@ -8043,30 +8824,30 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
 
     uint32_t ncols = src0->ne[0];
 
-    uint32_t ncols_pad = 1;
-    while (ncols_pad < ncols) {
-        ncols_pad *= 2;
-    }
-
-    GGML_ASSERT(ncols_pad <= 1024);
-
     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
-        ncols_pad,
         op_params[0],
     }, dryrun);
 }
 
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
 }
 
 static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
+}
+
+static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    p.weight = 1.0f / (float)src0->ne[0];
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
 }
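The new GGML_OP_MEAN path above reuses the sum_rows pipeline and only changes the per-element weight, so mean(x) == sum_rows(x) * (1/ne[0]). A minimal sketch of that equivalence (standalone C++, illustrative only; not the Vulkan shader itself):

    #include <cstdio>
    #include <vector>

    // weight == 1.0f       -> behaves like SUM_ROWS
    // weight == 1.0f / n   -> behaves like MEAN
    static float weighted_row_sum(const std::vector<float> & row, float weight) {
        float acc = 0.0f;
        for (float v : row) {
            acc += v;
        }
        return acc * weight;
    }

    int main() {
        const std::vector<float> row = {1.0f, 2.0f, 3.0f, 4.0f};
        std::printf("sum  = %f\n", weighted_row_sum(row, 1.0f));                      // 10
        std::printf("mean = %f\n", weighted_row_sum(row, 1.0f / (float) row.size())); // 2.5
    }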
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun);
 }
 
 static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8178,13 +8959,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
 
 static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
                             const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
     GGML_ASSERT(nb0 == sizeof(float));
 
@@ -9190,6 +9971,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
+        // Resize buffer
+        if (ctx->prealloc_add_rms_partials != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
+        }
+        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
+    }
 }
 
 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
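The prealloc_add_rms_partials block added above follows the same grow-only pattern as the other prealloc_* buffers: the device buffer is recreated only when the requested size exceeds the current capacity. A reduced sketch of that pattern (standalone C++ with host memory standing in for vk_buffer; illustrative only):

    #include <cstddef>
    #include <cstdlib>

    struct DeviceBuffer {
        void * ptr  = nullptr;
        size_t size = 0;
    };

    // Recreate the buffer only when it is missing or too small; otherwise
    // keep the existing allocation, as ggml_vk_preallocate_buffers does.
    static void ensure_capacity(DeviceBuffer & buf, size_t needed) {
        if (buf.ptr != nullptr && buf.size >= needed) {
            return;
        }
        std::free(buf.ptr);
        buf.ptr  = std::malloc(needed);
        buf.size = needed;
    }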
@@ -9220,6 +10009,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -9237,6 +10027,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             break;
@@ -9244,10 +10035,24 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
+        {
+            int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
+            if (next_node_idx < cgraph->n_nodes &&
+                cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
+                cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
+                ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
+                ctx->device->add_rms_fusion) {
+                if (dryrun) {
+                    ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+                }
+                ctx->do_add_rms_partials = true;
+            }
+        } break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
+    case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -9256,6 +10061,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -9281,6 +10087,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
@@ -9294,11 +10101,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
         GGML_ABORT("fatal error");
-        return false;
     }
 
     vk_context compute_ctx;
@@ -9325,6 +10132,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -9349,6 +10157,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
@@ -9358,11 +10167,15 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
             vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
             ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+            if (node->op == GGML_OP_RMS_NORM) {
+                ctx->do_add_rms_partials = false;
+            }
             return false;
         }
     default:
@@ -9370,6 +10183,80 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         }
     }
 
+    if (!dryrun) {
+        // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
+        // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
+        // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
+        // outside of this logic. When a node uses one of the prealloc buffers for something like
+        // dequantization or split_k, additional synchronization is needed between those passes.
+        bool need_sync = false;
+
+        // Check whether "node" requires synchronization. The node requires synchronization if it
+        // overlaps in memory with another unsynchronized node and at least one of them is a write.
+        // Destination nodes are checked against both the written/read lists. Source nodes are only
+        // checked against the written list. Two nodes overlap in memory if they come from the same
+        // buffer and the tensor or view ranges overlap.
+        auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
+            if (unsynced_nodes.size() == 0) {
+                return false;
+            }
+            auto n_base = vk_tensor_offset(node) + node->view_offs;
+            auto n_size = ggml_nbytes(node);
+            ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
+            vk_buffer a_buf = a_buf_ctx->dev_buffer;
+            for (auto &other : unsynced_nodes) {
+                ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
+                vk_buffer o_buf = o_buf_ctx->dev_buffer;
+                if (a_buf == o_buf) {
+                    auto o_base = vk_tensor_offset(other) + other->view_offs;
+                    auto o_size = ggml_nbytes(other);
+
+                    if ((o_base <= n_base && n_base < o_base + o_size) ||
+                        (n_base <= o_base && o_base < n_base + n_size)) {
+                        return true;
+                    }
+                }
+            }
+            return false;
+        };
+
+        // For all fused ops, check if the destination node or any of the source
+        // nodes require synchronization.
+        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
+            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+            if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
+                need_sync = true;
+                break;
+            }
+            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
+                if (!cur_node->src[j]) {
+                    continue;
+                }
+                if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
+                    need_sync = true;
+                    break;
+                }
+            }
+        }
+        if (need_sync) {
+            ctx->unsynced_nodes_written.clear();
+            ctx->unsynced_nodes_read.clear();
+            ggml_vk_sync_buffers(ctx, compute_ctx);
+        }
+        // Add the last fused node and all fused source nodes to the unsynchronized list.
+        const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+        ctx->unsynced_nodes_written.push_back(last_node);
+        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
+                if (!cur_node->src[j]) {
+                    continue;
+                }
+                ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
+            }
+        }
+    }
+
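The hazard test in overlaps_unsynced above reduces to a half-open interval intersection: two byte ranges [base, base + size) within the same vk_buffer conflict iff each starts before the other ends. A standalone sketch of just that predicate (illustrative only):

    #include <cassert>
    #include <cstdint>

    // True iff [a_base, a_base + a_size) and [b_base, b_base + b_size) intersect.
    static bool ranges_overlap(uint64_t a_base, uint64_t a_size,
                               uint64_t b_base, uint64_t b_size) {
        return (b_base <= a_base && a_base < b_base + b_size) ||
               (a_base <= b_base && b_base < a_base + a_size);
    }

    int main() {
        assert( ranges_overlap(0, 16,  8, 16));  // partial overlap
        assert(!ranges_overlap(0, 16, 16, 16));  // adjacent ranges do not overlap
        assert( ranges_overlap(4,  4,  0, 16));  // containment
    }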
     switch (node->op) {
     case GGML_OP_REPEAT:
         ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
@@ -9388,8 +10275,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
+        }
         break;
     case GGML_OP_SUB:
         ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9402,6 +10292,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DIV:
         ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_ADD_ID:
+        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     case GGML_OP_CONCAT:
         ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9418,6 +10312,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SQR:
         ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SQRT:
+        ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_SIN:
         ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
@@ -9481,6 +10379,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -9499,6 +10398,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9512,7 +10412,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ -9538,6 +10438,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_MEAN:
+        ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_ARGMAX:
         ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
@@ -9585,7 +10489,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;
 
     case GGML_OP_FLASH_ATTN_EXT:
-        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
+        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun);
 
         break;
 
@@ -9602,6 +10506,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_SGD:
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     default:
         return false;
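GGML_OP_OPT_STEP_SGD, dispatched above, belongs to the plain SGD-with-weight-decay family of optimizer updates. A generic sketch of that update rule (standalone C++, illustrative only; the exact parameter packing of the ggml op is not reproduced here):

    #include <vector>

    // w <- w * (1 - lr * wd) - lr * g   (SGD step with weight decay)
    static void sgd_step(std::vector<float> & w, const std::vector<float> & g,
                         float lr, float wd) {
        for (size_t i = 0; i < w.size(); ++i) {
            w[i] = w[i] * (1.0f - lr * wd) - lr * g[i];
        }
    }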
@@ -9658,10 +10567,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -9690,6 +10601,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_ARGSORT:
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
@@ -9704,11 +10616,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
-
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
@@ -9727,6 +10640,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
@@ -9804,6 +10718,11 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
         ggml_vk_pool_free(ctx, buffer);
     }
     ctx->gc.temp_buffers.clear();
+    ctx->prealloc_y_last_pipeline_used = {};
+
+    ctx->unsynced_nodes_written.clear();
+    ctx->unsynced_nodes_read.clear();
+    ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
 
     ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
     ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
@@ -9839,6 +10758,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ggml_vk_destroy_buffer(ctx->prealloc_x);
     ggml_vk_destroy_buffer(ctx->prealloc_y);
     ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+    ctx->prealloc_y_last_pipeline_used = nullptr;
 
     for (auto& buffer : ctx->buffer_pool) {
         ggml_vk_destroy_buffer(buffer);
@@ -10259,6 +11179,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }
 
+static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
+
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    if (first_node->op != GGML_OP_ADD) {
+        return 0;
+    }
+
+    if (!ctx->device->multi_add) {
+        return 0;
+    }
+
+    int32_t num_adds = 1;
+    while (node_idx + num_adds < cgraph->n_nodes &&
+           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
+           num_adds < MAX_FUSED_ADDS) {
+        num_adds++;
+    }
+
+    // The shader currently requires same shapes (but different strides are allowed),
+    // everything f32, and no misalignment
+    for (int32_t i = 0; i < num_adds; ++i) {
+        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
+        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
+            !ggml_are_same_shape(first_node, next_node->src[1]) ||
+            next_node->type != GGML_TYPE_F32 ||
+            next_node->src[0]->type != GGML_TYPE_F32 ||
+            next_node->src[1]->type != GGML_TYPE_F32 ||
+            get_misalign_bytes(ctx, next_node) ||
+            get_misalign_bytes(ctx, next_node->src[0]) ||
+            get_misalign_bytes(ctx, next_node->src[1])) {
+            num_adds = i;
+        }
+    }
+
+    // Verify we can fuse these
+    ggml_op adds[MAX_FUSED_ADDS];
+    for (int32_t i = 0; i < num_adds; ++i) {
+        adds[i] = GGML_OP_ADD;
+    }
+
+    // decrease num_adds if they can't all be fused
+    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
+        num_adds--;
+    }
+
+    // a single add is not "fused", so just return zero
+    if (num_adds == 1) {
+        return 0;
+    }
+    return num_adds;
+}
+
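The core of ggml_vk_fuse_multi_add above is a scan for a run of consecutive GGML_OP_ADD nodes, clamped by MAX_FUSED_ADDS and then validated. A reduced sketch of that scan with the graph types stubbed out (standalone C++, illustrative only):

    #include <cstdio>

    enum Op { OP_ADD, OP_MUL, OP_NORM };

    // Count consecutive adds starting at idx; a single add is not "fused".
    static int count_fusible_adds(const Op * ops, int n, int idx, int max_fused) {
        int num = 0;
        while (idx + num < n && ops[idx + num] == OP_ADD && num < max_fused) {
            num++;
        }
        return num > 1 ? num : 0;
    }

    int main() {
        const Op graph[] = { OP_ADD, OP_ADD, OP_ADD, OP_NORM };
        // Prints 3: one multi_add dispatch can then cover all three adds,
        // with num_additional_fused_ops set to 2 by the caller.
        std::printf("fusible adds: %d\n", count_fusible_adds(graph, 4, 0, 8));
    }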
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10270,10 +11242,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
+    ctx->prealloc_size_add_rms_partials = 0;
+    ctx->prealloc_size_add_rms_partials_offset = 0;
+    ctx->do_add_rms_partials = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10330,6 +11311,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }
 
+    ctx->prealloc_y_last_pipeline_used = nullptr;
+    ctx->prealloc_y_last_tensor_used = nullptr;
+
+    if (ctx->prealloc_size_add_rms_partials) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize partial sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
+        ggml_vk_sync_buffers(ctx, compute_ctx);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
@@ -10348,8 +11345,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
 
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
@@ -10456,10 +11458,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
     ggml_vk_init(ctx, dev_num);
 
     ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_vk_guid(),
-        /* .interface = */ ggml_backend_vk_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
-        /* .context   = */ ctx,
+        /* .guid    = */ ggml_backend_vk_guid(),
+        /* .iface   = */ ggml_backend_vk_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
+        /* .context = */ ctx,
     };
 
     return vk_backend;
@@ -10556,6 +11558,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     switch (op->op) {
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(op)) {
+        case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
@@ -10570,12 +11573,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         default:
             return false;
         }
-        break;
     case GGML_OP_GLU:
         switch (ggml_get_glu_op(op)) {
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous(op->src[0]) &&
@@ -10585,7 +11588,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         default:
             return false;
         }
-        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
         {
@@ -10621,6 +11623,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_TYPE_IQ3_S:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_MXFP4:
                 break;
             default:
                 return false;
@@ -10648,14 +11651,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
 
             return true;
-        } break;
+        }
     case GGML_OP_FLASH_ATTN_EXT:
         {
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
             auto device = ggml_vk_get_device(ctx->device);
             bool coopmat2 = device->coopmat2;
-            FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
-            if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
+            uint32_t HSK = op->src[1]->ne[0];
+            uint32_t HSV = op->src[2]->ne[0];
+            if ((HSK % 8) != 0 || (HSV % 8) != 0) {
+                return false;
+            }
+            if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
                 return false;
             }
             if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10730,11 +11737,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_TYPE_IQ3_S:
             case GGML_TYPE_IQ4_XS:
             case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_MXFP4:
                 return true;
             default:
                 return false;
             }
-        } break;
+        }
     case GGML_OP_SET_ROWS:
         {
             switch (op->type) {
@@ -10751,7 +11759,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             default:
                 return false;
             }
-        } break;
+        }
     case GGML_OP_CONT:
     case GGML_OP_CPY:
     case GGML_OP_DUP:
@@ -10803,7 +11811,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return true;
             }
             return false;
-        } break;
+        }
     case GGML_OP_REPEAT:
         return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
     case GGML_OP_REPEAT_BACK:
@@ -10828,13 +11836,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+    case GGML_OP_ADD_ID:
+        return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+               op->type == GGML_TYPE_F32;
     case GGML_OP_SILU_BACK:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
+    case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         return op->src[0]->type == GGML_TYPE_F32;
+    case GGML_OP_ARGSORT:
+        return op->ne[0] <= max_argsort_cols;
     case GGML_OP_UPSCALE:
     case GGML_OP_ACC:
     case GGML_OP_CONCAT:
@@ -10844,9 +11861,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
-    case GGML_OP_ARGSORT:
+        return true;
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_MEAN:
+        return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
@@ -10855,8 +11874,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_OPT_STEP_ADAMW:
         return true;
     case GGML_OP_CONV_TRANSPOSE_1D:
         return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@@ -10865,14 +11882,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
             const vk_device& device = ggml_vk_get_device(ctx->device);
-            bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
             // Channel-contiguous format is not supported yet.
-            return (op->src[0]->type == GGML_TYPE_F32 &&
+            return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32 &&
                 ggml_is_contiguous(op->src[0]) &&
                 ggml_is_contiguous(op->src[1]) &&
-                ggml_is_contiguous(op)) && !is_Apple;
+                ggml_is_contiguous(op));
         }
     default:
         return false;
@@ -11147,7 +12163,7 @@ size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
 static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
     ggml_tensor * tensor = cgraph->nodes[tensor_idx];
-    if (tensor->op == GGML_OP_TRANSPOSE) {
+    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
         return;
     }
 
@@ -11246,6 +12262,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
+        if (src_clone[4]) {
+            ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
+        }
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
@@ -11264,12 +12283,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
+        tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
-        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
+        tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_SQRT) {
+        tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
         tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {
@@ -11340,6 +12361,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         }
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
+        case GGML_UNARY_OP_EXP:
+            tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
+            break;
         case GGML_UNARY_OP_SILU:
             tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
             break;
@@ -11371,6 +12395,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
         }
+        ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
+        ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -11378,8 +12404,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
-    } else if (tensor->op == GGML_OP_SET_ROWS) {
-        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONT) {
         tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {
@@ -11399,6 +12423,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_MEAN) {
+        tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_ARGMAX) {
         tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
@@ -11453,6 +12479,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
                                            src_clone[2], src_clone[3], src_clone[4]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
+                                         src_clone[2]);
+    } else if (tensor->op == GGML_OP_ADD_ID) {
+        tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
@@ -11487,14 +12519,12 @@
 
 static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
     ggml_tensor * tensor = cgraph->nodes[tensor_idx];
-    if (tensor->op == GGML_OP_TRANSPOSE) {
+    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
         return;
     }
-    bool fused_rms_norm_mul = false;
     if (ctx->num_additional_fused_ops == 1 &&
         tensor->op == GGML_OP_RMS_NORM &&
         cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
-        fused_rms_norm_mul = true;
         tensor = cgraph->nodes[tensor_idx + 1];
     }
 
@@ -11547,6 +12577,9 @@
     } else if (tensor->type == GGML_TYPE_F16) {
         correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
         result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
+    } else if (tensor->type == GGML_TYPE_BF16) {
+        correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
+        result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
     } else if (tensor->type == GGML_TYPE_I32) {
         correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
         result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
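The new GGML_TYPE_BF16 branch above relies on bf16 being the top 16 bits of an IEEE-754 float, so widening to f32 is just a 16-bit left shift. A minimal sketch (standalone C++, illustrative only; rounding on the narrowing path is omitted):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float bf16_to_f32(uint16_t h) {
        const uint32_t bits = (uint32_t) h << 16;  // low mantissa bits restored as zero
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    int main() {
        std::printf("%f\n", bf16_to_f32(0x3F80));  // 1.000000
        std::printf("%f\n", bf16_to_f32(0x4049));  // 3.140625 (pi, truncated)
    }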