@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -224,6 +224,21 @@ enum vk_device_architecture {
224
224
  INTEL_XE2,
225
225
  };
226
226
 
227
+ // HSK x HSV
228
+ enum FaHeadSizes {
229
+ FA_HEAD_SIZE_64,
230
+ FA_HEAD_SIZE_80,
231
+ FA_HEAD_SIZE_96,
232
+ FA_HEAD_SIZE_112,
233
+ FA_HEAD_SIZE_128,
234
+ FA_HEAD_SIZE_192,
235
+ FA_HEAD_SIZE_192_128,
236
+ FA_HEAD_SIZE_256,
237
+ FA_HEAD_SIZE_576_512,
238
+ FA_HEAD_SIZE_UNSUPPORTED,
239
+ FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
240
+ };
241
+
227
242
  static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
228
243
  vk::PhysicalDeviceProperties props = device.getProperties();
229
244
 
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
305
320
  }
306
321
 
307
322
  struct vk_device_struct {
308
- std::mutex mutex;
323
+ std::recursive_mutex mutex;
309
324
 
310
325
  vk::PhysicalDevice physical_device;
311
326
  vk::PhysicalDeviceProperties properties;
@@ -313,6 +328,7 @@ struct vk_device_struct {
313
328
  uint64_t max_memory_allocation_size;
314
329
  uint64_t suballocation_block_size;
315
330
  bool fp16;
331
+ bool bf16;
316
332
  bool pipeline_robustness;
317
333
  vk::Device device;
318
334
  uint32_t vendor_id;
@@ -410,32 +426,42 @@ struct vk_device_struct {
410
426
  vk_pipeline pipeline_div_norepeat[2][2][2];
411
427
 
412
428
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
- vk_pipeline pipeline_upscale_f32;
429
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
414
430
  vk_pipeline pipeline_scale_f32;
415
431
  vk_pipeline pipeline_sqr_f32;
416
432
  vk_pipeline pipeline_sin_f32;
417
433
  vk_pipeline pipeline_cos_f32;
418
434
  vk_pipeline pipeline_clamp_f32;
419
435
  vk_pipeline pipeline_pad_f32;
436
+ vk_pipeline pipeline_roll_f32;
420
437
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
421
438
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
422
439
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
423
440
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
424
441
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
442
+ vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
425
443
  vk_pipeline pipeline_norm_f32;
426
444
  vk_pipeline pipeline_group_norm_f32;
427
445
  vk_pipeline pipeline_rms_norm_f32;
446
+ vk_pipeline pipeline_rms_norm_mul_f32;
428
447
  vk_pipeline pipeline_rms_norm_back_f32;
429
448
  vk_pipeline pipeline_l2_norm_f32;
430
449
 
431
450
  // [src/dst 0=fp32,1=fp16]
432
451
  vk_pipeline pipeline_gelu[2];
452
+ vk_pipeline pipeline_gelu_erf[2];
433
453
  vk_pipeline pipeline_gelu_quick[2];
434
454
  vk_pipeline pipeline_silu[2];
435
455
  vk_pipeline pipeline_relu[2];
436
456
  vk_pipeline pipeline_tanh[2];
437
457
  vk_pipeline pipeline_sigmoid[2];
438
458
 
459
+ vk_pipeline pipeline_geglu[2];
460
+ vk_pipeline pipeline_reglu[2];
461
+ vk_pipeline pipeline_swiglu[2];
462
+ vk_pipeline pipeline_geglu_erf[2];
463
+ vk_pipeline pipeline_geglu_quick[2];
464
+
439
465
  vk_pipeline pipeline_leaky_relu_f32;
440
466
  vk_pipeline pipeline_silu_back_f32;
441
467
  vk_pipeline pipeline_diag_mask_inf_f32;
@@ -457,30 +483,16 @@ struct vk_device_struct {
457
483
  vk_pipeline pipeline_rwkv_wkv6_f32;
458
484
  vk_pipeline pipeline_rwkv_wkv7_f32;
459
485
  vk_pipeline pipeline_opt_step_adamw_f32;
486
+ vk_pipeline pipeline_conv2d_f32;
460
487
  vk_pipeline pipeline_conv2d_dw_whcn_f32;
461
488
  vk_pipeline pipeline_conv2d_dw_cwhn_f32;
462
489
 
463
490
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
464
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
465
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
466
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
467
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
468
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
469
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
470
-
471
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
472
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
473
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
474
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
475
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
476
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
477
-
478
- vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
479
- vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
480
- vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
481
- vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
482
- vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
483
- vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
491
+ vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
492
+
493
+ vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
494
+
495
+ vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
484
496
 
485
497
  vk_pipeline pipeline_flash_attn_split_k_reduce;
486
498
 
@@ -493,6 +505,8 @@ struct vk_device_struct {
493
505
 
494
506
  ggml_backend_buffer_type buffer_type;
495
507
 
508
+ bool disable_fusion;
509
+
496
510
  #ifdef GGML_VULKAN_MEMORY_DEBUG
497
511
  std::unique_ptr<vk_memory_logger> memory_logger;
498
512
  #endif
@@ -627,6 +641,8 @@ struct vk_flash_attn_push_constants {
627
641
  uint32_t nev2;
628
642
  uint32_t nev3;
629
643
  uint32_t nem1;
644
+ uint32_t nem2;
645
+ uint32_t nem3;
630
646
 
631
647
  uint32_t nb01;
632
648
  uint32_t nb02;
@@ -637,14 +653,12 @@ struct vk_flash_attn_push_constants {
637
653
  uint32_t nb21;
638
654
  uint32_t nb22;
639
655
  uint32_t nb23;
640
- uint32_t nb31;
641
656
 
642
657
  float scale;
643
658
  float max_bias;
644
659
  float logit_softcap;
645
660
 
646
- uint32_t mask;
647
- uint32_t n_head_log2;
661
+ uint32_t mask_n_head_log2;
648
662
  float m0;
649
663
  float m1;
650
664
 
@@ -652,6 +666,7 @@ struct vk_flash_attn_push_constants {
652
666
  uint32_t split_kv;
653
667
  uint32_t k_num;
654
668
  };
669
+ static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
655
670
 
656
671
  struct vk_op_push_constants {
657
672
  uint32_t KX;
@@ -660,6 +675,13 @@ struct vk_op_push_constants {
660
675
  float param2;
661
676
  };
662
677
 
678
+ struct vk_op_glu_push_constants {
679
+ uint32_t N;
680
+ uint32_t ne00;
681
+ uint32_t ne20;
682
+ uint32_t mode; // 0: default, 1: swapped, 2: split
683
+ };
684
+
663
685
  struct vk_op_unary_push_constants {
664
686
  uint32_t ne;
665
687
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -675,6 +697,37 @@ struct vk_op_unary_push_constants {
675
697
  };
676
698
  static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
677
699
 
700
+ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
701
+ GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
702
+ ne = ne != 0 ? ne : ggml_nelements(dst);
703
+ GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
704
+
705
+ vk_op_unary_push_constants p{};
706
+ p.ne = (uint32_t)ne;
707
+
708
+ size_t src0_tsize = ggml_type_size(src0->type);
709
+ p.ne00 = (uint32_t)src0->ne[0];
710
+ p.ne01 = (uint32_t)src0->ne[1];
711
+ p.ne02 = (uint32_t)src0->ne[2];
712
+ p.ne03 = (uint32_t)src0->ne[3];
713
+ p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
714
+ p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
715
+ p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
716
+ p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
717
+
718
+ size_t dst_tsize = ggml_type_size(dst->type);
719
+ p.ne10 = (uint32_t)dst->ne[0];
720
+ p.ne11 = (uint32_t)dst->ne[1];
721
+ p.ne12 = (uint32_t)dst->ne[2];
722
+ p.ne13 = (uint32_t)dst->ne[3];
723
+ p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
724
+ p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
725
+ p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
726
+ p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
727
+
728
+ return p; // fastdiv values and offsets are initialized later in ggml_vk_op
729
+ }
730
+
678
731
  // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
679
732
  // Precompute mp (m' in the paper) and L such that division
680
733
  // can be computed using a multiply (high 32b of 64b result)
@@ -743,6 +796,14 @@ struct vk_op_rope_push_constants {
743
796
  struct vk_op_soft_max_push_constants {
744
797
  uint32_t KX;
745
798
  uint32_t KY;
799
+ uint32_t ne00;
800
+ uint32_t ne01;
801
+ uint32_t ne02;
802
+ uint32_t ne12;
803
+ uint32_t ne13;
804
+ uint32_t nb11;
805
+ uint32_t nb12;
806
+ uint32_t nb13;
746
807
  float scale;
747
808
  float max_bias;
748
809
  float m0;
@@ -816,6 +877,38 @@ struct vk_op_rwkv_wkv7_push_constants {
816
877
  uint32_t H;
817
878
  };
818
879
 
880
+ struct vk_op_conv2d_push_constants {
881
+ uint32_t Cout;
882
+ uint32_t Cin;
883
+ uint32_t N;
884
+
885
+ uint32_t KW;
886
+ uint32_t KH;
887
+ uint32_t W;
888
+ uint32_t H;
889
+ uint32_t OW;
890
+ uint32_t OH;
891
+
892
+ uint32_t s0;
893
+ uint32_t s1;
894
+ uint32_t p0;
895
+ uint32_t p1;
896
+ uint32_t d0;
897
+ uint32_t d1;
898
+
899
+ uint32_t nb01;
900
+ uint32_t nb02;
901
+ uint32_t nb03;
902
+
903
+ uint32_t nb11;
904
+ uint32_t nb12;
905
+ uint32_t nb13;
906
+
907
+ uint32_t nb1;
908
+ uint32_t nb2;
909
+ uint32_t nb3;
910
+ };
911
+
819
912
  struct vk_op_conv2d_dw_push_constants {
820
913
  uint32_t ne;
821
914
  uint32_t batches;
@@ -836,6 +929,7 @@ struct vk_op_conv2d_dw_push_constants {
836
929
 
837
930
  struct vk_op_upscale_push_constants {
838
931
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
932
+ uint32_t ne00; uint32_t ne01;
839
933
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
840
934
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
841
935
  float sf0; float sf1; float sf2; float sf3;
@@ -914,18 +1008,45 @@ private:
914
1008
  #endif // GGML_VULKAN_MEMORY_DEBUG
915
1009
 
916
1010
  class vk_perf_logger {
917
- public:
1011
+ public:
918
1012
  void print_timings() {
1013
+ if (timings.empty()) {
1014
+ return;
1015
+ }
1016
+ uint64_t total_all_op_times = 0;
919
1017
  std::cerr << "----------------\nVulkan Timings:" << std::endl;
920
- for (const auto& t : timings) {
921
- uint64_t total = 0;
922
- for (const auto& time : t.second) {
923
- total += time;
1018
+ for (const auto & t : timings) {
1019
+ uint64_t total_op_times = 0;
1020
+ for (const auto & time : t.second) {
1021
+ total_op_times += time;
1022
+ }
1023
+ std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
1024
+ << " us";
1025
+
1026
+ // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
1027
+ auto it = flops.find(t.first);
1028
+ if (it != flops.end() && (it->second).size() == t.second.size()) {
1029
+ uint64_t total_op_flops = 0;
1030
+ for (const auto & elem : it->second) {
1031
+ total_op_flops += elem;
1032
+ }
1033
+ std::cerr << " ("
1034
+ << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
1035
+ (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
1036
+ << " GFLOPS/s)";
924
1037
  }
925
- std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
1038
+
1039
+ total_all_op_times += total_op_times;
1040
+
1041
+ std::cerr << std::endl;
1042
+ }
1043
+
1044
+ if (timings.size() > 0) {
1045
+ std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
926
1046
  }
927
1047
 
928
1048
  timings.clear();
1049
+ flops.clear();
929
1050
  }
930
1051
 
931
1052
  void log_timing(const ggml_tensor * node, uint64_t time) {
@@ -934,22 +1055,45 @@ public:
934
1055
  return;
935
1056
  }
936
1057
  if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
937
- const uint64_t m = node->src[0]->ne[1];
938
- const uint64_t n = node->src[1]->ne[1];
939
- const uint64_t k = node->src[1]->ne[0];
940
- std::string name = ggml_op_name(node->op);
1058
+ const uint64_t m = node->src[0]->ne[1];
1059
+ const uint64_t n = node->src[1]->ne[1];
1060
+ const uint64_t k = node->src[1]->ne[0];
1061
+ std::string name = ggml_op_name(node->op);
941
1062
  if (n == 1) {
942
1063
  name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k);
943
1064
  } else {
944
1065
  name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
945
1066
  }
946
1067
  timings[name].push_back(time);
1068
+ flops[name].push_back(m * n * (k + (k - 1)));
1069
+ return;
1070
+ }
1071
+ if (node->op == GGML_OP_CONV_2D) {
1072
+ std::string name = ggml_op_name(node->op);
1073
+ ggml_tensor * knl = node->src[0];
1074
+ uint64_t OW = node->ne[0];
1075
+ uint64_t OH = node->ne[1];
1076
+ uint64_t N = node->ne[3];
1077
+ uint64_t Cout = node->ne[2];
1078
+ uint64_t KW = knl->ne[0];
1079
+ uint64_t KH = knl->ne[1];
1080
+ uint64_t Cin = knl->ne[2];
1081
+ // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
1082
+ uint64_t size_M = Cout;
1083
+ uint64_t size_K = Cin * KW * KH;
1084
+ uint64_t size_N = N * OW * OH;
1085
+ uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
1086
+ name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
1087
+ ", N=N*OW*OH=" + std::to_string(size_N);
1088
+ flops[name].push_back(n_flops);
1089
+ timings[name].push_back(time);
947
1090
  return;
948
1091
  }
949
1092
  timings[ggml_op_name(node->op)].push_back(time);
950
1093
  }
951
- private:
1094
+ private:
952
1095
  std::map<std::string, std::vector<uint64_t>> timings;
1096
+ std::map<std::string, std::vector<uint64_t>> flops;
953
1097
  };
954
1098
 
955
1099
  struct ggml_backend_vk_context {
@@ -978,6 +1122,10 @@ struct ggml_backend_vk_context {
978
1122
 
979
1123
  vk_command_pool compute_cmd_pool;
980
1124
  vk_command_pool transfer_cmd_pool;
1125
+
1126
+ // number of additional consecutive nodes that are being fused with the
1127
+ // node currently being processed
1128
+ int num_additional_fused_ops {};
981
1129
  };
982
1130
 
983
1131
  static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1063,8 +1211,8 @@ static size_t vk_skip_checks;
1063
1211
  static size_t vk_output_tensor;
1064
1212
 
1065
1213
  static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
1066
- static void ggml_vk_check_results_0(ggml_tensor * tensor);
1067
- static void ggml_vk_check_results_1(ggml_tensor * tensor);
1214
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1215
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1068
1216
  #endif
1069
1217
 
1070
1218
  typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -1197,7 +1345,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1197
1345
  }
1198
1346
 
1199
1347
  {
1200
- std::lock_guard<std::mutex> guard(device->mutex);
1348
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1201
1349
  device->pipelines.insert({ pipeline->name, pipeline });
1202
1350
  }
1203
1351
 
@@ -1411,7 +1559,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
1411
1559
 
1412
1560
  static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
1413
1561
  VK_LOG_DEBUG("ggml_vk_create_queue()");
1414
- std::lock_guard<std::mutex> guard(device->mutex);
1562
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1415
1563
 
1416
1564
  q.queue_family_index = queue_family_index;
1417
1565
  q.transfer_only = transfer_only;
@@ -1673,10 +1821,46 @@ enum FaCodePath {
1673
1821
  FA_COOPMAT2,
1674
1822
  };
1675
1823
 
1824
+ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
1825
+ if (hsk != 192 && hsk != 576 && hsk != hsv) {
1826
+ return FA_HEAD_SIZE_UNSUPPORTED;
1827
+ }
1828
+ switch (hsk) {
1829
+ case 64: return FA_HEAD_SIZE_64;
1830
+ case 80: return FA_HEAD_SIZE_80;
1831
+ case 96: return FA_HEAD_SIZE_96;
1832
+ case 112: return FA_HEAD_SIZE_112;
1833
+ case 128: return FA_HEAD_SIZE_128;
1834
+ case 192:
1835
+ if (hsv == 192) {
1836
+ return FA_HEAD_SIZE_192;
1837
+ } else if (hsv == 128) {
1838
+ return FA_HEAD_SIZE_192_128;
1839
+ } else {
1840
+ return FA_HEAD_SIZE_UNSUPPORTED;
1841
+ }
1842
+ case 256: return FA_HEAD_SIZE_256;
1843
+ case 576:
1844
+ if (hsv == 512) {
1845
+ return FA_HEAD_SIZE_576_512;
1846
+ } else {
1847
+ return FA_HEAD_SIZE_UNSUPPORTED;
1848
+ }
1849
+ default: return FA_HEAD_SIZE_UNSUPPORTED;
1850
+ }
1851
+ }
1852
+
1676
1853
  // number of rows/cols for flash attention shader
1677
1854
  static constexpr uint32_t flash_attention_num_small_rows = 32;
1678
1855
  static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1679
- static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1856
+
1857
+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
1858
+ if (hsv >= 512) {
1859
+ return 2;
1860
+ } else {
1861
+ return 8;
1862
+ }
1863
+ }
1680
1864
 
1681
1865
  // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
1682
1866
  // 128 threads split into four subgroups, each subgroup does 1/4
@@ -1693,14 +1877,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
1693
1877
  }
1694
1878
  }
1695
1879
 
1696
- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1880
+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
1697
1881
  GGML_UNUSED(clamp);
1882
+ GGML_UNUSED(hsv);
1698
1883
 
1699
1884
  if (path == FA_SCALAR) {
1700
1885
  if (small_rows) {
1701
1886
  return {scalar_flash_attention_num_small_rows, 64};
1702
1887
  } else {
1703
- return {scalar_flash_attention_num_large_rows, 32};
1888
+ return {get_fa_scalar_num_large_rows(hsv), 32};
1704
1889
  }
1705
1890
  }
1706
1891
 
@@ -1718,8 +1903,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
1718
1903
  }
1719
1904
 
1720
1905
  // small cols to reduce register count
1721
- if (ggml_is_quantized(type) || D == 256) {
1722
- return {64, 32};
1906
+ if (ggml_is_quantized(type) || hsk >= 256) {
1907
+ if (hsk >= 512) {
1908
+ return {32, 32};
1909
+ } else {
1910
+ return {64, 32};
1911
+ }
1723
1912
  }
1724
1913
  return {64, 64};
1725
1914
  }
@@ -1761,7 +1950,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1761
1950
  const uint32_t warps = warptile[0] / warptile[10];
1762
1951
 
1763
1952
  const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1764
- const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1953
+ const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
1765
1954
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1766
1955
 
1767
1956
  const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1886,10 +2075,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1886
2075
  s_mmq_wg_denoms_k = { 32, 32, 1 };
1887
2076
 
1888
2077
  // spec constants and tile sizes for quant matmul_id
1889
- l_warptile_mmqid = { 256, 128, 64, 16, 0 };
2078
+ l_warptile_mmqid = { 256, 128, 128, 16, 0 };
1890
2079
  m_warptile_mmqid = { 256, 128, 64, 16, 0 };
1891
2080
  s_warptile_mmqid = { 256, 128, 64, 16, 0 };
1892
- l_mmqid_wg_denoms = { 128, 64, 1 };
2081
+ l_mmqid_wg_denoms = { 128, 128, 1 };
1893
2082
  m_mmqid_wg_denoms = { 128, 64, 1 };
1894
2083
  s_mmqid_wg_denoms = { 128, 64, 1 };
1895
2084
 
@@ -2007,23 +2196,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
2007
2196
  }
2008
2197
  compile_count++;
2009
2198
  }
2199
+
2010
2200
  compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
2011
2201
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
2012
2202
  };
2013
2203
 
2014
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2015
- return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
2204
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2205
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2016
2206
  };
2017
2207
 
2018
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2208
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2019
2209
  // For large number of rows, 128 invocations seems to work best.
2020
2210
  // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
2021
2211
  // can't use 256 for D==80.
2022
2212
  // For scalar, use 128 (arbitrary)
2213
+ // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
2214
+ const uint32_t D = (hsk|hsv);
2023
2215
  uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
2024
2216
  ? scalar_flash_attention_workgroup_size
2025
2217
  : ((small_rows && (D % 32) == 0) ? 256 : 128);
2026
- auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
2218
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
2027
2219
 
2028
2220
  // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
2029
2221
  // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2032,26 +2224,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
2032
2224
 
2033
2225
  // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
2034
2226
  GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
2035
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
2227
+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
2036
2228
  };
2037
2229
 
2038
- #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
2039
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2040
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2041
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2042
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2043
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2044
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2045
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2046
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2230
+ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
2231
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2232
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2233
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2234
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2235
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2236
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2237
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2238
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2047
2239
 
2048
2240
  #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
2049
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
2050
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
2051
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
2052
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
2053
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
2054
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
2241
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
2242
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
2243
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
2244
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
2245
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
2246
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
2247
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
2248
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
2249
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
2055
2250
 
2056
2251
  CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
2057
2252
  CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2641,7 +2836,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2641
2836
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2642
2837
 
2643
2838
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2644
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
2839
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
2645
2840
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2646
2841
 
2647
2842
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2655,7 +2850,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2655
2850
 
2656
2851
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2657
2852
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2853
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2854
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
2659
2855
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2660
2856
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2661
2857
 
@@ -2672,19 +2868,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
2672
2868
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2673
2869
 
2674
2870
  if (device->float_controls_rte_fp16) {
2675
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2676
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2677
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2678
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2679
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2680
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2871
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2872
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2873
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2874
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2875
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2876
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2877
+ } else {
2878
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2879
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2880
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2881
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2882
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2883
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2884
+ }
2885
+
2886
+ if (device->float_controls_rte_fp16) {
2887
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2888
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2889
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2890
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2891
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2892
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2893
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2894
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2895
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2681
2896
  } else {
2682
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2683
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2684
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2685
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2686
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2687
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2897
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2898
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2899
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2900
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2901
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2902
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2903
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2904
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2905
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2688
2906
  }
2689
2907
 
2690
2908
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -2702,10 +2920,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2702
2920
  return s;
2703
2921
  };
2704
2922
 
2923
+ bool rte = device->float_controls_rte_fp16;
2705
2924
  #define CREATE_BINARY(name, namemod, spec) \
2706
2925
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2707
2926
  ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2708
- #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2927
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
2709
2928
  "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2710
2929
 
2711
2930
  CREATE_BINARY(add, , {0})
@@ -2724,7 +2943,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2724
2943
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2725
2944
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2726
2945
 
2727
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2946
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2947
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2948
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2728
2949
 
2729
2950
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2730
2951
 
@@ -2736,6 +2957,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2736
2957
 
2737
2958
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2738
2959
 
2960
+ ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2961
+
2739
2962
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2740
2963
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2741
2964
 
@@ -2744,6 +2967,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2744
2967
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2745
2968
 
2746
2969
  CREATE_UNARY(gelu)
2970
+ CREATE_UNARY(gelu_erf)
2747
2971
  CREATE_UNARY(gelu_quick)
2748
2972
  CREATE_UNARY(silu)
2749
2973
  CREATE_UNARY(relu)
@@ -2751,6 +2975,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
2751
2975
  CREATE_UNARY(sigmoid)
2752
2976
  #undef CREATE_UNARY
2753
2977
 
2978
+ #define CREATE_GLU(name) \
2979
+ if (device->float_controls_rte_fp16) { \
2980
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2981
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2982
+ } else { \
2983
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2984
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2985
+ }
2986
+
2987
+ CREATE_GLU(geglu)
2988
+ CREATE_GLU(reglu)
2989
+ CREATE_GLU(swiglu)
2990
+ CREATE_GLU(geglu_erf)
2991
+ CREATE_GLU(geglu_quick)
2992
+ #undef CREATE_GLU
2993
+
2754
2994
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2755
2995
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2756
2996
 
@@ -2806,6 +3046,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
2806
3046
 
2807
3047
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2808
3048
 
3049
+ // conv2d
3050
+ uint32_t conv2d_WG_SIZE = 256;
3051
+ uint32_t conv2d_BS_K = 128;
3052
+ uint32_t conv2d_BS_CRS = 16;
3053
+ uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
3054
+ if (device->subgroup_shuffle &&
3055
+ device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316
3056
+ use_collectives = 1;
3057
+ conv2d_BS_CRS = std::min(
3058
+ device->subgroup_size,
3059
+ conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
3060
+ }
3061
+ uint32_t conv2d_BS_NPQ = 128;
3062
+ uint32_t conv2d_TS_K = 8;
3063
+ uint32_t conv2d_shmem_req =
3064
+ (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float);
3065
+ if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
3066
+ conv2d_BS_CRS = 8;
3067
+ if (use_collectives) {
3068
+ conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
3069
+ }
3070
+ }
3071
+
3072
+ if (use_collectives) {
3073
+ ggml_vk_create_pipeline(
3074
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3075
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3076
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3077
+ } else {
3078
+ ggml_vk_create_pipeline(
3079
+ device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3080
+ sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3081
+ { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3082
+ false);
3083
+ }
3084
+
2809
3085
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2810
3086
  ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2811
3087
 
@@ -3118,6 +3394,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
3118
3394
 
3119
3395
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
3120
3396
 
3397
+ #if defined(VK_KHR_shader_bfloat16)
3398
+ device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
3399
+ #else
3400
+ device->bf16 = false;
3401
+ #endif
3402
+
3121
3403
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
3122
3404
 
3123
3405
  if (device->subgroup_size_control) {
@@ -3431,6 +3713,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
3431
3713
 
3432
3714
  device->idx = idx;
3433
3715
 
3716
+ device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
3717
+
3434
3718
  return device;
3435
3719
  }
3436
3720
 
@@ -3458,6 +3742,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3458
3742
  bool coopmat_support = false;
3459
3743
  bool coopmat2_support = false;
3460
3744
  bool integer_dot_product = false;
3745
+ bool bfloat16_support = false;
3461
3746
 
3462
3747
  for (auto properties : ext_props) {
3463
3748
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -3478,6 +3763,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3478
3763
  } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
3479
3764
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
3480
3765
  integer_dot_product = true;
3766
+ #endif
3767
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3768
+ } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
3769
+ !getenv("GGML_VK_DISABLE_BFLOAT16")) {
3770
+ bfloat16_support = true;
3481
3771
  #endif
3482
3772
  }
3483
3773
  }
@@ -3544,10 +3834,25 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3544
3834
  last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
3545
3835
  }
3546
3836
 
3837
+ #if defined(VK_KHR_shader_bfloat16)
3838
+ VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
3839
+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
3840
+ if (bfloat16_support) {
3841
+ last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
3842
+ last_struct = (VkBaseOutStructure *)&bfloat16_features;
3843
+ }
3844
+ #endif
3845
+
3547
3846
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
3548
3847
 
3549
3848
  fp16 = fp16 && vk12_features.shaderFloat16;
3550
3849
 
3850
+ #if defined(VK_KHR_shader_bfloat16)
3851
+ bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
3852
+ #else
3853
+ bool bf16 = false;
3854
+ #endif
3855
+
3551
3856
  uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
3552
3857
  const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
3553
3858
  const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
@@ -3565,8 +3870,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3565
3870
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
3566
3871
 
3567
3872
  std::string device_name = props2.properties.deviceName.data();
3568
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
3569
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
3873
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
3874
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
3570
3875
  props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
3571
3876
 
3572
3877
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@@ -3651,7 +3956,6 @@ static void ggml_vk_instance_init() {
3651
3956
 
3652
3957
  }
3653
3958
 
3654
- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
3655
3959
  vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
3656
3960
 
3657
3961
  // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -4124,6 +4428,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
4124
4428
  return nullptr;
4125
4429
  }
4126
4430
 
4431
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4127
4432
  device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
4128
4433
 
4129
4434
  return buf->ptr;
@@ -4134,6 +4439,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4134
4439
  return;
4135
4440
  }
4136
4441
  VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
4442
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4443
+
4137
4444
  vk_buffer buf;
4138
4445
  size_t index;
4139
4446
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4156,6 +4463,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4156
4463
  }
4157
4464
 
4158
4465
  static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
4466
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4159
4467
  buf = nullptr;
4160
4468
  buf_offset = 0;
4161
4469
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4457,7 +4765,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
4457
4765
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
4458
4766
  }
4459
4767
  } else {
4460
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4768
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4461
4769
 
4462
4770
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4463
4771
  ggml_vk_ctx_begin(dst->device, subctx);
@@ -4548,7 +4856,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
4548
4856
 
4549
4857
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
4550
4858
  } else {
4551
- std::lock_guard<std::mutex> guard(src->device->mutex);
4859
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4552
4860
 
4553
4861
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
4554
4862
  ggml_vk_ctx_begin(src->device, subctx);
@@ -4578,7 +4886,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
4578
4886
 
4579
4887
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
4580
4888
  if (src->device == dst->device) {
4581
- std::lock_guard<std::mutex> guard(src->device->mutex);
4889
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4582
4890
  VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
4583
4891
  // Copy within the device
4584
4892
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
@@ -4613,7 +4921,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
4613
4921
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
4614
4922
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
4615
4923
 
4616
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4924
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4617
4925
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4618
4926
  ggml_vk_ctx_begin(dst->device, subctx);
4619
4927
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
@@ -4762,7 +5070,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
4762
5070
  return
4763
5071
  tensor->nb[0] == ggml_type_size(tensor->type) &&
4764
5072
  tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
4765
- tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
5073
+ (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
4766
5074
  }
4767
5075
 
4768
5076
  static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
@@ -4840,9 +5148,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4840
5148
  // type size must be exactly 2 or 4.
4841
5149
  GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
4842
5150
  if ((ggml_type_size(src->type) % 4) == 0) {
4843
- return ctx->device->pipeline_contig_cpy_f32_f32;
5151
+ if (contig) {
5152
+ return ctx->device->pipeline_contig_cpy_f32_f32;
5153
+ } else {
5154
+ return ctx->device->pipeline_cpy_f32_f32;
5155
+ }
4844
5156
  } else {
4845
- return ctx->device->pipeline_contig_cpy_f16_f16;
5157
+ if (contig) {
5158
+ return ctx->device->pipeline_contig_cpy_f16_f16;
5159
+ } else {
5160
+ return ctx->device->pipeline_cpy_f16_f16;
5161
+ }
4846
5162
  }
4847
5163
  }
4848
5164
 
@@ -4903,7 +5219,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4903
5219
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
4904
5220
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4905
5221
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4906
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5222
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
4907
5223
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
4908
5224
 
4909
5225
  const uint64_t ne00 = src0->ne[0];
@@ -5131,7 +5447,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
5131
5447
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
5132
5448
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5133
5449
  std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
5134
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5450
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5135
5451
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5136
5452
 
5137
5453
  const uint64_t ne00 = src0->ne[0];
@@ -5732,7 +6048,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5732
6048
  std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
5733
6049
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5734
6050
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
5735
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
6051
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5736
6052
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5737
6053
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
5738
6054
 
@@ -5926,14 +6242,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
5926
6242
  if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5927
6243
  ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
5928
6244
  } else {
5929
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
6245
+ // Split based on number of ids, to fit in shared memory
6246
+ const uint32_t nei0 = (uint32_t)src2->ne[0];
6247
+ const uint32_t nei1 = (uint32_t)src2->ne[1];
6248
+
6249
+ GGML_ASSERT(nei0 <= 4096);
6250
+ const uint32_t split_size = std::min(nei1, 4096u / nei0);
6251
+
6252
+ ggml_tensor src1_copy = *src1;
6253
+ ggml_tensor src2_copy = *src2;
6254
+ ggml_tensor dst_copy = *dst;
6255
+
6256
+ for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
6257
+ const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
6258
+
6259
+ src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
6260
+ src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
6261
+ dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
6262
+
6263
+ src1_copy.ne[2] = n_tokens;
6264
+ src2_copy.ne[1] = n_tokens;
6265
+ dst_copy.ne[2] = n_tokens;
6266
+
6267
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
6268
+ }
5930
6269
  }
5931
6270
  }
5932
6271
 
5933
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
6272
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
6273
+ // Needs to be kept up to date on shader changes
6274
+ GGML_UNUSED(hsv);
6275
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6276
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
6277
+ const uint32_t Bc = scalar_flash_attention_Bc;
6278
+
6279
+ const uint32_t tmpsh = wg_size * sizeof(float);
6280
+ const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
6281
+
6282
+ const uint32_t masksh = Bc * Br * sizeof(float);
6283
+
6284
+ const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
6285
+
6286
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
6287
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
6288
+
6289
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
6290
+
6291
+ return supported;
6292
+ }
6293
+
6294
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
5934
6295
  // Needs to be kept up to date on shader changes
6296
+ GGML_UNUSED(hsv);
5935
6297
  const uint32_t wg_size = scalar_flash_attention_workgroup_size;
5936
- const uint32_t Br = scalar_flash_attention_num_large_rows;
6298
+ const uint32_t Br = coopmat1_flash_attention_num_large_rows;
5937
6299
  const uint32_t Bc = scalar_flash_attention_Bc;
5938
6300
 
5939
6301
  const uint32_t acctype = f32acc ? 4 : 2;
@@ -5942,12 +6304,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5942
6304
  const uint32_t tmpsh = wg_size * sizeof(float);
5943
6305
  const uint32_t tmpshv4 = wg_size * 4 * acctype;
5944
6306
 
5945
- const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
6307
+ const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
5946
6308
 
5947
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
6309
+ const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
5948
6310
  const uint32_t sfsh = Bc * sfshstride * acctype;
5949
6311
 
5950
- const uint32_t kshstride = D / 4 + 2;
6312
+ const uint32_t kshstride = hsk / 4 + 2;
5951
6313
  const uint32_t ksh = Bc * kshstride * f16vec4;
5952
6314
 
5953
6315
  const uint32_t slope = Br * sizeof(float);
@@ -5955,7 +6317,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5955
6317
  const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
5956
6318
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
5957
6319
 
5958
- VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
6320
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
5959
6321
 
5960
6322
  return supported;
5961
6323
  }
@@ -5977,13 +6339,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5977
6339
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
5978
6340
 
5979
6341
  const uint32_t nem1 = mask ? mask->ne[1] : 0;
5980
- const uint32_t nbm1 = mask ? mask->nb[1] : 0;
6342
+ const uint32_t nem2 = mask ? mask->ne[2] : 0;
6343
+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
5981
6344
 
5982
- const uint32_t D = neq0;
6345
+ const uint32_t HSK = nek0;
6346
+ const uint32_t HSV = nev0;
5983
6347
  uint32_t N = neq1;
5984
6348
  const uint32_t KV = nek1;
5985
6349
 
5986
- GGML_ASSERT(ne0 == D);
6350
+ GGML_ASSERT(ne0 == HSV);
5987
6351
  GGML_ASSERT(ne2 == N);
5988
6352
 
5989
6353
  // input tensor rows must be contiguous
@@ -5991,12 +6355,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5991
6355
  GGML_ASSERT(nbk0 == ggml_type_size(k->type));
5992
6356
  GGML_ASSERT(nbv0 == ggml_type_size(v->type));
5993
6357
 
5994
- GGML_ASSERT(neq0 == D);
5995
- GGML_ASSERT(nek0 == D);
5996
- GGML_ASSERT(nev0 == D);
6358
+ GGML_ASSERT(neq0 == HSK);
5997
6359
 
5998
6360
  GGML_ASSERT(neq1 == N);
5999
- GGML_ASSERT(nev0 == D);
6000
6361
 
6001
6362
  GGML_ASSERT(nev1 == nek1);
6002
6363
 
@@ -6017,7 +6378,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6017
6378
  const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
6018
6379
  (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
6019
6380
 
6020
- const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
6381
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
6021
6382
 
6022
6383
  if (!coopmat_shape_supported || !coopmat_shmem_supported) {
6023
6384
  path = FA_SCALAR;
@@ -6037,7 +6398,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6037
6398
  case FA_SCALAR:
6038
6399
  case FA_COOPMAT1:
6039
6400
  // We may switch from coopmat1 to scalar, so use the scalar limit for both
6040
- max_gqa = scalar_flash_attention_num_large_rows;
6401
+ max_gqa = get_fa_scalar_num_large_rows(HSV);
6041
6402
  break;
6042
6403
  case FA_COOPMAT2:
6043
6404
  max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -6047,7 +6408,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6047
6408
  }
6048
6409
 
6049
6410
  if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6050
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6411
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
6051
6412
  // grouped query attention - make the N dimension equal to gqa_ratio, reduce
6052
6413
  // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
6053
6414
  // and change addressing calculations to index Q's dimension 2.
@@ -6070,47 +6431,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6070
6431
  path = FA_SCALAR;
6071
6432
  }
6072
6433
 
6434
+ // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
6435
+ if (path == FA_SCALAR &&
6436
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
6437
+ small_rows = true;
6438
+ }
6439
+
6073
6440
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
6074
6441
 
6442
+ FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
6443
+
6075
6444
  switch (path) {
6076
6445
  case FA_SCALAR:
6077
- switch (D) {
6078
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
6079
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
6080
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
6081
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
6082
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
6083
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
6084
- default:
6085
- GGML_ASSERT(!"unsupported D value");
6086
- return;
6087
- }
6446
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
6088
6447
  break;
6089
6448
  case FA_COOPMAT1:
6090
- switch (D) {
6091
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
6092
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
6093
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
6094
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
6095
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
6096
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
6097
- default:
6098
- GGML_ASSERT(!"unsupported D value");
6099
- return;
6100
- }
6449
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
6101
6450
  break;
6102
6451
  case FA_COOPMAT2:
6103
- switch (D) {
6104
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
6105
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
6106
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
6107
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
6108
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
6109
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
6110
- default:
6111
- GGML_ASSERT(!"unsupported D value");
6112
- return;
6113
- }
6452
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
6114
6453
  break;
6115
6454
  default:
6116
6455
  GGML_ASSERT(0);
@@ -6138,21 +6477,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6138
6477
  const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
6139
6478
 
6140
6479
  // Try to use split_k when KV is large enough to be worth the overhead
6141
- if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
6480
+ if (workgroups_x == 1 && shader_core_count > 0) {
6142
6481
  // Try to run two workgroups per SM.
6143
- split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6482
+ split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
6144
6483
  if (split_k > 1) {
6145
6484
  // Try to evenly split KV into split_k chunks, but it needs to be a multiple
6146
6485
  // of "align", so recompute split_k based on that.
6147
- split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
6486
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
6148
6487
  split_k = CEIL_DIV(KV, split_kv);
6149
6488
  workgroups_x = split_k;
6150
6489
  }
6151
6490
  }
6152
6491
 
6153
- // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6154
- // and the per-row m and L values (ne1 rows).
6155
- const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6492
+ // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (HSV x ne1)
6493
+ // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6494
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
6156
6495
  if (split_k_size > ctx->device->max_memory_allocation_size) {
6157
6496
  GGML_ABORT("Requested preallocation size is too large");
6158
6497
  }
@@ -6239,18 +6578,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6239
6578
  }
6240
6579
  }
6241
6580
 
6581
+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6582
+
6242
6583
  const vk_flash_attn_push_constants pc = { N, KV,
6243
6584
  (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
6244
6585
  (uint32_t)neq2, (uint32_t)neq3,
6245
6586
  (uint32_t)nek2, (uint32_t)nek3,
6246
6587
  (uint32_t)nev2, (uint32_t)nev3,
6247
- nem1,
6588
+ nem1, nem2, nem3,
6248
6589
  q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
6249
6590
  k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
6250
6591
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6251
- nbm1,
6252
6592
  scale, max_bias, logit_softcap,
6253
- mask != nullptr, n_head_log2, m0, m1,
6593
+ mask_n_head_log2, m0, m1,
6254
6594
  gqa_ratio, split_kv, split_k };
6255
6595
 
6256
6596
  ggml_vk_sync_buffers(subctx);
@@ -6271,13 +6611,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6271
6611
  pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
6272
6612
 
6273
6613
  ggml_vk_sync_buffers(subctx);
6274
- const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6614
+ const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
6275
6615
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
6276
6616
  {
6277
6617
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
6278
6618
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
6279
6619
  },
6280
- pc2, { (uint32_t)ne1, 1, 1 });
6620
+ pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
6281
6621
  } else {
6282
6622
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
6283
6623
  {
@@ -6353,8 +6693,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6353
6693
  }
6354
6694
  return nullptr;
6355
6695
  case GGML_OP_UPSCALE:
6356
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6357
- return ctx->device->pipeline_upscale_f32;
6696
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6697
+ int mode = ggml_get_op_params_i32(dst, 0);
6698
+ switch (mode) {
6699
+ case GGML_SCALE_MODE_NEAREST:
6700
+ return ctx->device->pipeline_upscale_nearest_f32;
6701
+ case GGML_SCALE_MODE_BILINEAR:
6702
+ return ctx->device->pipeline_upscale_bilinear_f32;
6703
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6704
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6705
+ }
6358
6706
  }
6359
6707
  return nullptr;
6360
6708
  case GGML_OP_SCALE:
@@ -6387,6 +6735,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6387
6735
  return ctx->device->pipeline_pad_f32;
6388
6736
  }
6389
6737
  return nullptr;
6738
+ case GGML_OP_ROLL:
6739
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6740
+ return ctx->device->pipeline_roll_f32;
6741
+ }
6742
+ return nullptr;
6390
6743
  case GGML_OP_REPEAT:
6391
6744
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
6392
6745
  return ctx->device->pipeline_repeat_f32;
@@ -6401,6 +6754,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6401
6754
  case GGML_OP_CONT:
6402
6755
  case GGML_OP_DUP:
6403
6756
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
6757
+ case GGML_OP_SET_ROWS:
6758
+ return ctx->device->pipeline_set_rows[dst->type];
6404
6759
  case GGML_OP_SILU_BACK:
6405
6760
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6406
6761
  return ctx->device->pipeline_silu_back_f32;
@@ -6418,7 +6773,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6418
6773
  return nullptr;
6419
6774
  case GGML_OP_RMS_NORM:
6420
6775
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6421
- return ctx->device->pipeline_rms_norm_f32;
6776
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
6422
6777
  }
6423
6778
  return nullptr;
6424
6779
  case GGML_OP_RMS_NORM_BACK:
@@ -6443,6 +6798,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6443
6798
  return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
6444
6799
  case GGML_UNARY_OP_GELU:
6445
6800
  return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6801
+ case GGML_UNARY_OP_GELU_ERF:
6802
+ return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
6446
6803
  case GGML_UNARY_OP_GELU_QUICK:
6447
6804
  return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
6448
6805
  case GGML_UNARY_OP_RELU:
@@ -6455,6 +6812,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6455
6812
  break;
6456
6813
  }
6457
6814
  return nullptr;
6815
+ case GGML_OP_GLU:
6816
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6817
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6818
+ (src0->type != dst->type)) {
6819
+ return nullptr;
6820
+ }
6821
+
6822
+ switch (ggml_get_glu_op(dst)) {
6823
+ case GGML_GLU_OP_GEGLU:
6824
+ return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
6825
+ case GGML_GLU_OP_REGLU:
6826
+ return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6827
+ case GGML_GLU_OP_SWIGLU:
6828
+ return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6829
+ case GGML_GLU_OP_GEGLU_ERF:
6830
+ return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
6831
+ case GGML_GLU_OP_GEGLU_QUICK:
6832
+ return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
6833
+ default:
6834
+ break;
6835
+ }
6836
+ return nullptr;
6458
6837
  case GGML_OP_DIAG_MASK_INF:
6459
6838
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6460
6839
  return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6578,6 +6957,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6578
6957
  return ctx->device->pipeline_leaky_relu_f32;
6579
6958
  }
6580
6959
  return nullptr;
6960
+ case GGML_OP_CONV_2D:
6961
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
6962
+ ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
6963
+ return ctx->device->pipeline_conv2d_f32;
6964
+ }
6965
+ return nullptr;
6581
6966
  case GGML_OP_CONV_2D_DW:
6582
6967
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6583
6968
  if (ggml_is_contiguous(src1)) {
@@ -6615,6 +7000,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
6615
7000
  case GGML_OP_RMS_NORM:
6616
7001
  case GGML_OP_CONV_2D_DW:
6617
7002
  case GGML_OP_IM2COL:
7003
+ case GGML_OP_SET_ROWS:
6618
7004
  return true;
6619
7005
  default:
6620
7006
  return false;
@@ -6899,6 +7285,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6899
7285
  const uint32_t OW = dst->ne[0];
6900
7286
  elements = { N * OC * OH * OW, 1, 1};
6901
7287
  } break;
7288
+ case GGML_OP_CONV_2D:
7289
+ {
7290
+ // src0 - kernel: [KW, KH, Cin, Cout]
7291
+ // src1 - input: [W, H, Cin, N]
7292
+ // dst - result: [OW, OH, Cout, N]
7293
+
7294
+ // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
7295
+ auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7296
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7297
+ };
7298
+ // parallelize in {OW/BS_K, OH/BS_NPQ, 1}
7299
+ int64_t W = src1->ne[0];
7300
+ int64_t H = src1->ne[1];
7301
+ int64_t KW = src0->ne[0];
7302
+ int64_t KH = src0->ne[1];
7303
+ int64_t Cout = src0->ne[3];
7304
+ int64_t N = src1->ne[3];
7305
+ int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
7306
+ int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
7307
+ int64_t NPQ = N * OW * OH;
7308
+
7309
+ // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
7310
+ elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
7311
+ }
7312
+ break;
6902
7313
  case GGML_OP_ADD:
6903
7314
  case GGML_OP_SUB:
6904
7315
  case GGML_OP_DIV:
@@ -6909,12 +7320,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6909
7320
  case GGML_OP_COS:
6910
7321
  case GGML_OP_CLAMP:
6911
7322
  case GGML_OP_PAD:
7323
+ case GGML_OP_ROLL:
6912
7324
  case GGML_OP_REPEAT:
6913
7325
  case GGML_OP_REPEAT_BACK:
6914
7326
  case GGML_OP_CPY:
6915
7327
  case GGML_OP_CONCAT:
6916
7328
  case GGML_OP_UPSCALE:
6917
7329
  case GGML_OP_UNARY:
7330
+ case GGML_OP_GLU:
6918
7331
  case GGML_OP_CONV_2D_DW:
6919
7332
  {
6920
7333
  uint32_t ne = ggml_nelements(dst);
@@ -6927,6 +7340,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6927
7340
  ne *= ggml_type_size(src0->type) / 2;
6928
7341
  }
6929
7342
  }
7343
+ // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
7344
+ // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
7345
+ // So divide by block size here before splitting into 512x512 groups.
7346
+ if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7347
+ ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7348
+ }
6930
7349
  if (ne > 262144) {
6931
7350
  elements = { 512, 512, CEIL_DIV(ne, 262144) };
6932
7351
  } else if (ne > 512) {
@@ -6935,6 +7354,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6935
7354
  elements = { ne, 1, 1 };
6936
7355
  }
6937
7356
  } break;
7357
+ case GGML_OP_SET_ROWS:
7358
+ {
7359
+ uint32_t ne = ggml_nelements(src0);
7360
+ if (ggml_is_quantized(dst->type)) {
7361
+ // quants run 32 threads each doing QUANT_K elements
7362
+ ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
7363
+ } else {
7364
+ // scalar types do one element per thread, running 512 threads
7365
+ ne = CEIL_DIV(ne, 512);
7366
+ }
7367
+ if (ne > 262144) {
7368
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
7369
+ } else if (ne > 512) {
7370
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
7371
+ } else {
7372
+ elements = { ne, 1, 1 };
7373
+ }
7374
+ }
7375
+ break;
6938
7376
  default:
6939
7377
  elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
6940
7378
  break;
@@ -6955,7 +7393,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6955
7393
  }
6956
7394
  }
6957
7395
 
6958
- if (op == GGML_OP_SOFT_MAX) {
7396
+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
6959
7397
  // Empty src1 is possible in soft_max, but the shader needs a buffer
6960
7398
  vk_subbuffer subbuf_y;
6961
7399
  if (use_src1) {
@@ -7344,14 +7782,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7344
7782
 
7345
7783
  static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7346
7784
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7785
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7786
+
7787
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7788
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7789
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7790
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7347
7791
 
7348
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7349
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7350
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7351
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7792
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7793
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7794
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7795
+ }
7352
7796
 
7353
7797
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7354
7798
  (uint32_t)ggml_nelements(dst), 0, 0,
7799
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7355
7800
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7356
7801
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7357
7802
  sf0, sf1, sf2, sf3,
@@ -7359,123 +7804,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
7359
7804
  }
7360
7805
 
7361
7806
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7362
- float * op_params = (float *)dst->op_params;
7363
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7364
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7807
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7808
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7809
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7365
7810
 
7366
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7367
- (uint32_t)ggml_nelements(src0),
7368
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7369
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7370
- 0,
7371
- op_params[0], 0.0f,
7372
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7373
- }, dryrun);
7811
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
7374
7812
  }
7375
7813
 
7376
7814
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7377
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7378
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7379
-
7380
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7381
- (uint32_t)ggml_nelements(src0),
7382
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7383
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7384
- 0,
7385
- 0.0f, 0.0f,
7386
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7387
- }, dryrun);
7815
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
7388
7816
  }
7389
7817
 
7390
7818
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7391
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7392
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7393
-
7394
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7395
- (uint32_t)ggml_nelements(src0),
7396
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7397
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7398
- 0,
7399
- 0.0f, 0.0f,
7400
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7401
- }, dryrun);
7819
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
7402
7820
  }
7403
7821
 
7404
7822
  static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7405
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7406
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7407
-
7408
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7409
- (uint32_t)ggml_nelements(src0),
7410
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7411
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7412
- 0,
7413
- 0.0f, 0.0f,
7414
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7415
- }, dryrun);
7823
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
7416
7824
  }
7417
7825
 
7418
7826
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7419
- float * op_params = (float *)dst->op_params;
7420
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7421
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7827
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7828
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7829
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7422
7830
 
7423
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7424
- (uint32_t)ggml_nelements(src0),
7425
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7426
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7427
- 0,
7428
- op_params[0], op_params[1],
7429
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7430
- }, dryrun);
7831
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
7431
7832
  }
7432
7833
 
7433
7834
  static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7434
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7435
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7835
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7836
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7837
+ }
7436
7838
 
7437
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7438
- (uint32_t)ggml_nelements(dst),
7439
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7440
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7441
- 0,
7442
- 0.0f, 0.0f,
7443
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7444
- }, dryrun);
7839
+ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7840
+ const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7841
+ const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7842
+ const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7843
+ const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7844
+ const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7845
+ const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7846
+
7847
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7848
+ memcpy(&p.param1, &s01_packed, sizeof(float));
7849
+ memcpy(&p.param2, &s23_packed, sizeof(float));
7850
+
7851
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
7445
7852
  }
7446
7853
 
7447
7854
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7448
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7449
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7450
-
7451
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7452
- (uint32_t)ggml_nelements(dst),
7453
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7454
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7455
- 0,
7456
- 0.0f, 0.0f,
7457
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7458
- }, dryrun);
7855
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7856
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
7459
7857
  }
7460
7858
 
7461
7859
  static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7462
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7463
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7464
-
7465
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7466
- (uint32_t)ggml_nelements(dst),
7467
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7468
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7469
- 0,
7470
- 0.0f, 0.0f,
7471
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7472
- }, dryrun);
7860
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7861
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
7473
7862
  }
7474
7863
 
7475
7864
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7476
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7477
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7478
-
7479
7865
  uint32_t ne = (uint32_t)ggml_nelements(src0);
7480
7866
  if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7481
7867
  // Convert from number of logical elements to 2- or 4-byte units.
@@ -7487,13 +7873,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
7487
7873
  }
7488
7874
  }
7489
7875
 
7490
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7491
- ne,
7492
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7493
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7876
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7877
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
7878
+ }
7879
+
7880
+ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7881
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7882
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7883
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7884
+
7885
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
7886
+ (uint32_t)ggml_nelements(src0),
7887
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7888
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7889
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7494
7890
  0,
7495
- 0.0f, 0.0f,
7496
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7891
+ 0.0f, 0.0f, 0,
7497
7892
  }, dryrun);
7498
7893
  }
7499
7894
 
@@ -7518,18 +7913,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
7518
7913
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
7519
7914
  }
7520
7915
 
7521
- static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7522
- float * op_params = (float *)dst->op_params;
7916
+ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
7523
7917
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7918
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7524
7919
  const uint32_t dst_type_size = ggml_type_size(dst->type);
7525
7920
 
7526
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7921
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
7527
7922
  (uint32_t)ggml_nelements(src0),
7528
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7529
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7923
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7924
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7925
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7530
7926
  0,
7531
- op_params[0], 0.0f,
7532
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7927
+ op_params[0], 0.0f, 0,
7533
7928
  }, dryrun);
7534
7929
  }
7535
7930
 
@@ -7547,6 +7942,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
7547
7942
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7548
7943
  }
7549
7944
 
7945
+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7946
+ const bool swapped = (bool)dst->op_params[1];
7947
+ const bool split = src1 != nullptr;
7948
+
7949
+ GGML_ASSERT(ggml_is_contiguous(src0));
7950
+
7951
+ if (!split) {
7952
+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7953
+ } else {
7954
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7955
+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7956
+ GGML_ASSERT(src0->type == src1->type);
7957
+ }
7958
+
7959
+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
7960
+
7961
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
7962
+ }
7963
+
7550
7964
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7551
7965
  int32_t * op_params = (int32_t *)dst->op_params;
7552
7966
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -7562,7 +7976,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7562
7976
  const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
7563
7977
  const uint32_t nrows_y = (uint32_t)src0->ne[1];
7564
7978
 
7565
- const uint32_t n_head_kv = nrows_x/nrows_y;
7979
+ const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7980
+ const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7981
+ const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7982
+ const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7983
+ const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7984
+
7985
+ const uint32_t n_head_kv = src0->ne[2];
7566
7986
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
7567
7987
 
7568
7988
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7571,6 +7991,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7571
7991
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
7572
7992
  ncols,
7573
7993
  src1 != nullptr ? nrows_y : (uint32_t)0,
7994
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7995
+ ne12, ne13,
7996
+ nb11, nb12, nb13,
7574
7997
  scale, max_bias,
7575
7998
  m0, m1,
7576
7999
  n_head_log2,
@@ -7753,6 +8176,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
7753
8176
  }, dryrun);
7754
8177
  }
7755
8178
 
8179
+ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
8180
+ const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
8181
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
8182
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
8183
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
8184
+
8185
+ GGML_TENSOR_BINARY_OP_LOCALS
8186
+
8187
+ GGML_ASSERT(nb00 == sizeof(float));
8188
+ GGML_ASSERT(nb10 == sizeof(float));
8189
+ GGML_ASSERT(nb0 == sizeof(float));
8190
+
8191
+ vk_op_conv2d_push_constants p{};
8192
+ p.Cout = static_cast<uint32_t>(ne03);
8193
+ p.Cin = static_cast<uint32_t>(ne02);
8194
+ p.N = static_cast<uint32_t>(ne13);
8195
+
8196
+ p.KW = static_cast<uint32_t>(ne00);
8197
+ p.KH = static_cast<uint32_t>(ne01);
8198
+ p.W = static_cast<uint32_t>(ne10);
8199
+ p.H = static_cast<uint32_t>(ne11);
8200
+ p.OW = static_cast<uint32_t>(ne0);
8201
+ p.OH = static_cast<uint32_t>(ne1);
8202
+
8203
+ p.s0 = static_cast<uint32_t>(dst->op_params[0]);
8204
+ p.s1 = static_cast<uint32_t>(dst->op_params[1]);
8205
+ p.p0 = static_cast<uint32_t>(dst->op_params[2]);
8206
+ p.p1 = static_cast<uint32_t>(dst->op_params[3]);
8207
+ p.d0 = static_cast<uint32_t>(dst->op_params[4]);
8208
+ p.d1 = static_cast<uint32_t>(dst->op_params[5]);
8209
+
8210
+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
8211
+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
8212
+ p.nb03 = static_cast<uint32_t>(nb03 / nb00);
8213
+
8214
+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
8215
+ p.nb12 = static_cast<uint32_t>(nb12 / nb10);
8216
+ p.nb13 = static_cast<uint32_t>(nb13 / nb10);
8217
+
8218
+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
8219
+ p.nb2 = static_cast<uint32_t>(nb2 / nb0);
8220
+ p.nb3 = static_cast<uint32_t>(nb3 / nb0);
8221
+
8222
+ GGML_ASSERT(ne03 == ne2);
8223
+ GGML_ASSERT(ne02 == ne12);
8224
+
8225
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun);
8226
+ }
8227
+
7756
8228
  static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7757
8229
  vk_op_conv2d_dw_push_constants p{};
7758
8230
  p.ne = ggml_nelements(dst);
@@ -8720,11 +9192,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
8720
9192
  }
8721
9193
  }
8722
9194
 
8723
- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
9195
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
8724
9196
 
8725
9197
  // Returns true if node has enqueued work into the queue, false otherwise
8726
9198
  // If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
8727
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
9199
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
9200
+ ggml_tensor * node = cgraph->nodes[node_idx];
8728
9201
  if (ggml_is_empty(node) || !node->buffer) {
8729
9202
  return false;
8730
9203
  }
@@ -8749,6 +9222,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8749
9222
  switch (ggml_get_unary_op(node)) {
8750
9223
  case GGML_UNARY_OP_SILU:
8751
9224
  case GGML_UNARY_OP_GELU:
9225
+ case GGML_UNARY_OP_GELU_ERF:
8752
9226
  case GGML_UNARY_OP_GELU_QUICK:
8753
9227
  case GGML_UNARY_OP_RELU:
8754
9228
  case GGML_UNARY_OP_TANH:
@@ -8758,6 +9232,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8758
9232
  return false;
8759
9233
  }
8760
9234
  break;
9235
+ case GGML_OP_GLU:
9236
+ switch (ggml_get_glu_op(node)) {
9237
+ case GGML_GLU_OP_GEGLU:
9238
+ case GGML_GLU_OP_REGLU:
9239
+ case GGML_GLU_OP_SWIGLU:
9240
+ case GGML_GLU_OP_GEGLU_ERF:
9241
+ case GGML_GLU_OP_GEGLU_QUICK:
9242
+ break;
9243
+ default:
9244
+ return false;
9245
+ }
9246
+ break;
8761
9247
  case GGML_OP_REPEAT:
8762
9248
  case GGML_OP_REPEAT_BACK:
8763
9249
  case GGML_OP_GET_ROWS:
@@ -8774,7 +9260,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8774
9260
  case GGML_OP_COS:
8775
9261
  case GGML_OP_CLAMP:
8776
9262
  case GGML_OP_PAD:
9263
+ case GGML_OP_ROLL:
8777
9264
  case GGML_OP_CPY:
9265
+ case GGML_OP_SET_ROWS:
8778
9266
  case GGML_OP_CONT:
8779
9267
  case GGML_OP_DUP:
8780
9268
  case GGML_OP_SILU_BACK:
@@ -8799,6 +9287,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8799
9287
  case GGML_OP_TIMESTEP_EMBEDDING:
8800
9288
  case GGML_OP_CONV_TRANSPOSE_1D:
8801
9289
  case GGML_OP_POOL_2D:
9290
+ case GGML_OP_CONV_2D:
8802
9291
  case GGML_OP_CONV_2D_DW:
8803
9292
  case GGML_OP_RWKV_WKV6:
8804
9293
  case GGML_OP_RWKV_WKV7:
@@ -8841,6 +9330,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8841
9330
  case GGML_OP_CLAMP:
8842
9331
  case GGML_OP_PAD:
8843
9332
  case GGML_OP_CPY:
9333
+ case GGML_OP_SET_ROWS:
8844
9334
  case GGML_OP_CONT:
8845
9335
  case GGML_OP_DUP:
8846
9336
  case GGML_OP_SILU_BACK:
@@ -8850,6 +9340,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8850
9340
  case GGML_OP_RMS_NORM_BACK:
8851
9341
  case GGML_OP_L2_NORM:
8852
9342
  case GGML_OP_UNARY:
9343
+ case GGML_OP_GLU:
8853
9344
  case GGML_OP_DIAG_MASK_INF:
8854
9345
  case GGML_OP_SOFT_MAX:
8855
9346
  case GGML_OP_SOFT_MAX_BACK:
@@ -8864,6 +9355,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8864
9355
  case GGML_OP_TIMESTEP_EMBEDDING:
8865
9356
  case GGML_OP_CONV_TRANSPOSE_1D:
8866
9357
  case GGML_OP_POOL_2D:
9358
+ case GGML_OP_CONV_2D:
8867
9359
  case GGML_OP_CONV_2D_DW:
8868
9360
  case GGML_OP_LEAKY_RELU:
8869
9361
  {
@@ -8942,12 +9434,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8942
9434
  case GGML_OP_PAD:
8943
9435
  ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
8944
9436
 
9437
+ break;
9438
+ case GGML_OP_ROLL:
9439
+ ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9440
+
8945
9441
  break;
8946
9442
  case GGML_OP_CPY:
8947
9443
  case GGML_OP_CONT:
8948
9444
  case GGML_OP_DUP:
8949
9445
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
8950
9446
 
9447
+ break;
9448
+ case GGML_OP_SET_ROWS:
9449
+ ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
9450
+
8951
9451
  break;
8952
9452
  case GGML_OP_SILU_BACK:
8953
9453
  ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8962,8 +9462,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8962
9462
 
8963
9463
  break;
8964
9464
  case GGML_OP_RMS_NORM:
8965
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8966
-
9465
+ if (ctx->num_additional_fused_ops > 0) {
9466
+ // fused rms_norm + mul
9467
+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9468
+ ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
9469
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
9470
+ } else {
9471
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
9472
+ }
8967
9473
  break;
8968
9474
  case GGML_OP_RMS_NORM_BACK:
8969
9475
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8977,6 +9483,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8977
9483
  switch (ggml_get_unary_op(node)) {
8978
9484
  case GGML_UNARY_OP_SILU:
8979
9485
  case GGML_UNARY_OP_GELU:
9486
+ case GGML_UNARY_OP_GELU_ERF:
8980
9487
  case GGML_UNARY_OP_GELU_QUICK:
8981
9488
  case GGML_UNARY_OP_RELU:
8982
9489
  case GGML_UNARY_OP_TANH:
@@ -8987,6 +9494,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8987
9494
  return false;
8988
9495
  }
8989
9496
  break;
9497
+ case GGML_OP_GLU:
9498
+ switch (ggml_get_glu_op(node)) {
9499
+ case GGML_GLU_OP_GEGLU:
9500
+ case GGML_GLU_OP_REGLU:
9501
+ case GGML_GLU_OP_SWIGLU:
9502
+ case GGML_GLU_OP_GEGLU_ERF:
9503
+ case GGML_GLU_OP_GEGLU_QUICK:
9504
+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
9505
+ break;
9506
+ default:
9507
+ return false;
9508
+ }
9509
+ break;
8990
9510
  case GGML_OP_DIAG_MASK_INF:
8991
9511
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
8992
9512
 
@@ -9042,6 +9562,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9042
9562
  case GGML_OP_POOL_2D:
9043
9563
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
9044
9564
 
9565
+ break;
9566
+ case GGML_OP_CONV_2D:
9567
+ ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun);
9568
+
9045
9569
  break;
9046
9570
  case GGML_OP_CONV_2D_DW:
9047
9571
  ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9108,12 +9632,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9108
9632
 
9109
9633
  ctx->compute_ctx.reset();
9110
9634
 
9111
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
9635
+ bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
9112
9636
  if (!ok) {
9113
9637
  if (node->op == GGML_OP_UNARY) {
9114
9638
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
9115
- }
9116
- else {
9639
+ } else if (node->op == GGML_OP_GLU) {
9640
+ std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
9641
+ } else {
9117
9642
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
9118
9643
  }
9119
9644
  }
@@ -9122,7 +9647,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9122
9647
  return true;
9123
9648
  }
9124
9649
 
9125
- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9650
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9651
+ GGML_UNUSED(cgraph);
9126
9652
  ggml_backend_buffer * buf = nullptr;
9127
9653
 
9128
9654
  switch (tensor->op) {
@@ -9140,7 +9666,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9140
9666
  case GGML_OP_COS:
9141
9667
  case GGML_OP_CLAMP:
9142
9668
  case GGML_OP_PAD:
9669
+ case GGML_OP_ROLL:
9143
9670
  case GGML_OP_CPY:
9671
+ case GGML_OP_SET_ROWS:
9144
9672
  case GGML_OP_CONT:
9145
9673
  case GGML_OP_DUP:
9146
9674
  case GGML_OP_SILU_BACK:
@@ -9168,6 +9696,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9168
9696
  case GGML_OP_TIMESTEP_EMBEDDING:
9169
9697
  case GGML_OP_CONV_TRANSPOSE_1D:
9170
9698
  case GGML_OP_POOL_2D:
9699
+ case GGML_OP_CONV_2D:
9171
9700
  case GGML_OP_CONV_2D_DW:
9172
9701
  case GGML_OP_RWKV_WKV6:
9173
9702
  case GGML_OP_RWKV_WKV7:
@@ -9182,6 +9711,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9182
9711
  switch (ggml_get_unary_op(tensor)) {
9183
9712
  case GGML_UNARY_OP_SILU:
9184
9713
  case GGML_UNARY_OP_GELU:
9714
+ case GGML_UNARY_OP_GELU_ERF:
9185
9715
  case GGML_UNARY_OP_GELU_QUICK:
9186
9716
  case GGML_UNARY_OP_RELU:
9187
9717
  case GGML_UNARY_OP_TANH:
@@ -9192,6 +9722,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9192
9722
  return false;
9193
9723
  }
9194
9724
  break;
9725
+ case GGML_OP_GLU:
9726
+ switch (ggml_get_glu_op(tensor)) {
9727
+ case GGML_GLU_OP_GEGLU:
9728
+ case GGML_GLU_OP_REGLU:
9729
+ case GGML_GLU_OP_SWIGLU:
9730
+ case GGML_GLU_OP_GEGLU_ERF:
9731
+ case GGML_GLU_OP_GEGLU_QUICK:
9732
+ buf = tensor->buffer;
9733
+ break;
9734
+ default:
9735
+ return false;
9736
+ }
9737
+ break;
9195
9738
  case GGML_OP_MUL_MAT:
9196
9739
  case GGML_OP_MUL_MAT_ID:
9197
9740
  case GGML_OP_FLASH_ATTN_EXT:
@@ -9218,7 +9761,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9218
9761
  // Only run if ctx hasn't been submitted yet
9219
9762
  if (!subctx->seqs.empty()) {
9220
9763
  #ifdef GGML_VULKAN_CHECK_RESULTS
9221
- ggml_vk_check_results_0(tensor);
9764
+ ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
9222
9765
  use_fence = true;
9223
9766
  #endif
9224
9767
 
@@ -9238,7 +9781,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9238
9781
  ggml_vk_wait_for_fence(ctx);
9239
9782
  }
9240
9783
  #ifdef GGML_VULKAN_CHECK_RESULTS
9241
- ggml_vk_check_results_1(tensor);
9784
+ ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
9242
9785
  #endif
9243
9786
  }
9244
9787
 
@@ -9685,6 +10228,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
9685
10228
  return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
9686
10229
  }
9687
10230
 
10231
+ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
10232
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
10233
+ return false;
10234
+ }
10235
+
10236
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
10237
+ // additional constraints specific to this fusion
10238
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
10239
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
10240
+
10241
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
10242
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
10243
+ // rms_norm only supports f32
10244
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
10245
+ mul->src[1]->type != GGML_TYPE_F32 ||
10246
+ mul->type != GGML_TYPE_F32) {
10247
+ return false;
10248
+ }
10249
+ // if rms_norm is the B operand, then we don't handle broadcast
10250
+ if (rms_norm == mul->src[1] &&
10251
+ !ggml_are_same_shape(mul->src[0], rms_norm)) {
10252
+ return false;
10253
+ }
10254
+ // rms_norm shader assumes contiguous rows
10255
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
10256
+ return false;
10257
+ }
10258
+ }
10259
+ return true;
10260
+ }
10261
+
9688
10262
  static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9689
10263
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
9690
10264
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9698,10 +10272,21 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9698
10272
 
9699
10273
  uint64_t total_mat_mul_bytes = 0;
9700
10274
  for (int i = 0; i < cgraph->n_nodes; i++) {
9701
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
10275
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10276
+ ctx->num_additional_fused_ops = 1;
10277
+ }
10278
+ ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
9702
10279
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
9703
10280
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
10281
+ } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) {
10282
+ // Count CRS x NPQ x sizeof(*) bytes, i.e. as many bytes as mul_mat would have in im2col->mul_mat mode.
10283
+ auto CRS_size =
10284
+ cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2];
10285
+ auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3];
10286
+ total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type);
9704
10287
  }
10288
+ i += ctx->num_additional_fused_ops;
10289
+ ctx->num_additional_fused_ops = 0;
9705
10290
  }
9706
10291
  if (ctx->device->need_compiles) {
9707
10292
  ggml_vk_load_shaders(ctx->device);
@@ -9763,14 +10348,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9763
10348
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9764
10349
  }
9765
10350
 
10351
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10352
+ ctx->num_additional_fused_ops = 1;
10353
+ }
10354
+
9766
10355
  // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
9767
10356
  bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
9768
10357
  bool submit = (submitted_nodes >= nodes_per_submit) ||
9769
10358
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9770
- (i == last_node) ||
10359
+ (i + ctx->num_additional_fused_ops == last_node) ||
9771
10360
  (almost_ready && !ctx->almost_ready_fence_pending);
9772
10361
 
9773
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
10362
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
9774
10363
 
9775
10364
  if (vk_perf_logger_enabled) {
9776
10365
  if (ctx->compute_ctx.expired()) {
@@ -9780,7 +10369,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9780
10369
  } else {
9781
10370
  compute_ctx = ctx->compute_ctx.lock();
9782
10371
  }
9783
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
10372
+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
10373
+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
10374
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
10375
+ }
9784
10376
  }
9785
10377
 
9786
10378
  if (enqueued) {
@@ -9802,6 +10394,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9802
10394
  }
9803
10395
  submit_count++;
9804
10396
  }
10397
+ i += ctx->num_additional_fused_ops;
10398
+ ctx->num_additional_fused_ops = 0;
9805
10399
  }
9806
10400
 
9807
10401
  if (vk_perf_logger_enabled) {
@@ -9963,6 +10557,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9963
10557
  case GGML_OP_UNARY:
9964
10558
  switch (ggml_get_unary_op(op)) {
9965
10559
  case GGML_UNARY_OP_GELU:
10560
+ case GGML_UNARY_OP_GELU_ERF:
9966
10561
  case GGML_UNARY_OP_GELU_QUICK:
9967
10562
  case GGML_UNARY_OP_SILU:
9968
10563
  case GGML_UNARY_OP_RELU:
@@ -9976,15 +10571,32 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9976
10571
  return false;
9977
10572
  }
9978
10573
  break;
10574
+ case GGML_OP_GLU:
10575
+ switch (ggml_get_glu_op(op)) {
10576
+ case GGML_GLU_OP_GEGLU:
10577
+ case GGML_GLU_OP_REGLU:
10578
+ case GGML_GLU_OP_SWIGLU:
10579
+ case GGML_GLU_OP_GEGLU_ERF:
10580
+ case GGML_GLU_OP_GEGLU_QUICK:
10581
+ return ggml_is_contiguous(op->src[0]) &&
10582
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10583
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
10584
+ (op->src[0]->type == op->type);
10585
+ default:
10586
+ return false;
10587
+ }
10588
+ break;
9979
10589
  case GGML_OP_MUL_MAT:
9980
10590
  case GGML_OP_MUL_MAT_ID:
9981
10591
  {
9982
10592
  ggml_type src0_type = op->src[0]->type;
9983
10593
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
9984
10594
  const vk_device& device = ggml_vk_get_device(ctx->device);
9985
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
9986
- // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
9987
- return false;
10595
+ if (op->op == GGML_OP_MUL_MAT_ID) {
10596
+ if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10597
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10598
+ return false;
10599
+ }
9988
10600
  }
9989
10601
  switch (src0_type) {
9990
10602
  case GGML_TYPE_F32:
@@ -10042,19 +10654,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10042
10654
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10043
10655
  auto device = ggml_vk_get_device(ctx->device);
10044
10656
  bool coopmat2 = device->coopmat2;
10045
- switch (op->src[0]->ne[0]) {
10046
- case 64:
10047
- case 80:
10048
- case 96:
10049
- case 112:
10050
- case 128:
10051
- case 256:
10052
- break;
10053
- default:
10054
- return false;
10055
- }
10056
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
10057
- // different head sizes of K and V are not supported yet
10657
+ FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
10658
+ if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
10058
10659
  return false;
10059
10660
  }
10060
10661
  if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10134,6 +10735,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10134
10735
  return false;
10135
10736
  }
10136
10737
  } break;
10738
+ case GGML_OP_SET_ROWS:
10739
+ {
10740
+ switch (op->type) {
10741
+ case GGML_TYPE_F32:
10742
+ case GGML_TYPE_F16:
10743
+ case GGML_TYPE_BF16:
10744
+ case GGML_TYPE_Q4_0:
10745
+ case GGML_TYPE_Q4_1:
10746
+ case GGML_TYPE_Q5_0:
10747
+ case GGML_TYPE_Q5_1:
10748
+ case GGML_TYPE_Q8_0:
10749
+ case GGML_TYPE_IQ4_NL:
10750
+ return true;
10751
+ default:
10752
+ return false;
10753
+ }
10754
+ } break;
10137
10755
  case GGML_OP_CONT:
10138
10756
  case GGML_OP_CPY:
10139
10757
  case GGML_OP_DUP:
@@ -10218,11 +10836,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10218
10836
  case GGML_OP_CLAMP:
10219
10837
  return op->src[0]->type == GGML_TYPE_F32;
10220
10838
  case GGML_OP_UPSCALE:
10221
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10222
10839
  case GGML_OP_ACC:
10223
10840
  case GGML_OP_CONCAT:
10224
10841
  case GGML_OP_SCALE:
10225
10842
  case GGML_OP_PAD:
10843
+ case GGML_OP_ROLL:
10226
10844
  case GGML_OP_DIAG_MASK_INF:
10227
10845
  case GGML_OP_SOFT_MAX:
10228
10846
  case GGML_OP_SOFT_MAX_BACK:
@@ -10242,6 +10860,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10242
10860
  return true;
10243
10861
  case GGML_OP_CONV_TRANSPOSE_1D:
10244
10862
  return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
10863
+ case GGML_OP_CONV_2D:
10864
+ {
10865
+ // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
10866
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10867
+ const vk_device& device = ggml_vk_get_device(ctx->device);
10868
+ bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
10869
+ // Channel-contiguous format is not supported yet.
10870
+ return (op->src[0]->type == GGML_TYPE_F32 &&
10871
+ op->src[1]->type == GGML_TYPE_F32 &&
10872
+ op->type == GGML_TYPE_F32 &&
10873
+ ggml_is_contiguous(op->src[0]) &&
10874
+ ggml_is_contiguous(op->src[1]) &&
10875
+ ggml_is_contiguous(op)) && !is_Apple;
10876
+ }
10245
10877
  default:
10246
10878
  return false;
10247
10879
  }
@@ -10513,11 +11145,21 @@ void * comp_result;
10513
11145
  size_t comp_size;
10514
11146
  size_t comp_nb[GGML_MAX_DIMS];
10515
11147
  size_t check_counter = 0;
10516
- static void ggml_vk_check_results_0(ggml_tensor * tensor) {
11148
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
11149
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10517
11150
  if (tensor->op == GGML_OP_TRANSPOSE) {
10518
11151
  return;
10519
11152
  }
10520
11153
 
11154
+ bool fused_rms_norm_mul = false;
11155
+ int rms_norm_idx = -1;
11156
+ if (ctx->num_additional_fused_ops == 1 &&
11157
+ tensor->op == GGML_OP_RMS_NORM &&
11158
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
11159
+ fused_rms_norm_mul = true;
11160
+ tensor = cgraph->nodes[tensor_idx + 1];
11161
+ }
11162
+
10521
11163
  check_counter++;
10522
11164
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10523
11165
  return;
@@ -10545,6 +11187,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10545
11187
 
10546
11188
  for (int i = 0; i < 6; i++) {
10547
11189
  ggml_tensor * srci = tensor->src[i];
11190
+ if (fused_rms_norm_mul) {
11191
+ rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
11192
+ ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
11193
+ switch (i) {
11194
+ case 0: srci = rms_norm->src[0]; break;
11195
+ case 1: srci = tensor->src[1 - rms_norm_idx]; break;
11196
+ default: continue;
11197
+ }
11198
+ }
10548
11199
  if (srci == nullptr) {
10549
11200
  continue;
10550
11201
  }
@@ -10602,7 +11253,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10602
11253
  } else if (tensor->op == GGML_OP_SUB) {
10603
11254
  tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
10604
11255
  } else if (tensor->op == GGML_OP_MUL) {
10605
- tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
11256
+ if (fused_rms_norm_mul) {
11257
+ tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
11258
+ tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
11259
+ } else {
11260
+ tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
11261
+ }
10606
11262
  } else if (tensor->op == GGML_OP_DIV) {
10607
11263
  tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
10608
11264
  } else if (tensor->op == GGML_OP_CONCAT) {
@@ -10690,6 +11346,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10690
11346
  case GGML_UNARY_OP_GELU:
10691
11347
  tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
10692
11348
  break;
11349
+ case GGML_UNARY_OP_GELU_ERF:
11350
+ tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
11351
+ break;
10693
11352
  case GGML_UNARY_OP_GELU_QUICK:
10694
11353
  tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
10695
11354
  break;
@@ -10706,6 +11365,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10706
11365
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
10707
11366
  GGML_ABORT("fatal error");
10708
11367
  }
11368
+ } else if (tensor->op == GGML_OP_GLU) {
11369
+ if (src_clone[1] == nullptr) {
11370
+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
11371
+ } else {
11372
+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
11373
+ }
10709
11374
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
10710
11375
  if (src1 == nullptr) {
10711
11376
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -10713,6 +11378,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10713
11378
  } else {
10714
11379
  tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
10715
11380
  }
11381
+ } else if (tensor->op == GGML_OP_SET_ROWS) {
11382
+ tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
10716
11383
  } else if (tensor->op == GGML_OP_CONT) {
10717
11384
  tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
10718
11385
  } else if (tensor->op == GGML_OP_RESHAPE) {
@@ -10765,6 +11432,14 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10765
11432
  const int32_t p1 = tensor->op_params[6];
10766
11433
 
10767
11434
  tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
11435
+ } else if (tensor->op == GGML_OP_CONV_2D) {
11436
+ const int32_t s0 = tensor->op_params[0];
11437
+ const int32_t s1 = tensor->op_params[1];
11438
+ const int32_t p0 = tensor->op_params[2];
11439
+ const int32_t p1 = tensor->op_params[3];
11440
+ const int32_t d0 = tensor->op_params[4];
11441
+ const int32_t d1 = tensor->op_params[5];
11442
+ tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
10768
11443
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
10769
11444
  const float * op_params = (const float *)tensor->op_params;
10770
11445
  tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
@@ -10784,10 +11459,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10784
11459
  GGML_ABORT("fatal error");
10785
11460
  }
10786
11461
 
10787
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
10788
- ggml_build_forward_expand(cgraph, tensor_clone);
11462
+ ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
11463
+ ggml_build_forward_expand(cgraph_cpu, tensor_clone);
10789
11464
 
10790
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
11465
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
10791
11466
 
10792
11467
  if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
10793
11468
  ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -10810,10 +11485,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10810
11485
  VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
10811
11486
  }
10812
11487
 
10813
- static void ggml_vk_check_results_1(ggml_tensor * tensor) {
11488
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
11489
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10814
11490
  if (tensor->op == GGML_OP_TRANSPOSE) {
10815
11491
  return;
10816
11492
  }
11493
+ bool fused_rms_norm_mul = false;
11494
+ if (ctx->num_additional_fused_ops == 1 &&
11495
+ tensor->op == GGML_OP_RMS_NORM &&
11496
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
11497
+ fused_rms_norm_mul = true;
11498
+ tensor = cgraph->nodes[tensor_idx + 1];
11499
+ }
11500
+
10817
11501
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10818
11502
  return;
10819
11503
  }