@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -224,6 +224,21 @@ enum vk_device_architecture {
224
224
  INTEL_XE2,
225
225
  };
226
226
 
227
+ // HSK x HSV
228
+ enum FaHeadSizes {
229
+ FA_HEAD_SIZE_64,
230
+ FA_HEAD_SIZE_80,
231
+ FA_HEAD_SIZE_96,
232
+ FA_HEAD_SIZE_112,
233
+ FA_HEAD_SIZE_128,
234
+ FA_HEAD_SIZE_192,
235
+ FA_HEAD_SIZE_192_128,
236
+ FA_HEAD_SIZE_256,
237
+ FA_HEAD_SIZE_576_512,
238
+ FA_HEAD_SIZE_UNSUPPORTED,
239
+ FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
240
+ };
241
+
227
242
  static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
228
243
  vk::PhysicalDeviceProperties props = device.getProperties();
229
244
 
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
305
320
  }
306
321
 
307
322
  struct vk_device_struct {
308
- std::mutex mutex;
323
+ std::recursive_mutex mutex;
309
324
 
310
325
  vk::PhysicalDevice physical_device;
311
326
  vk::PhysicalDeviceProperties properties;
@@ -410,32 +425,42 @@ struct vk_device_struct {
410
425
  vk_pipeline pipeline_div_norepeat[2][2][2];
411
426
 
412
427
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
- vk_pipeline pipeline_upscale_f32;
428
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
414
429
  vk_pipeline pipeline_scale_f32;
415
430
  vk_pipeline pipeline_sqr_f32;
416
431
  vk_pipeline pipeline_sin_f32;
417
432
  vk_pipeline pipeline_cos_f32;
418
433
  vk_pipeline pipeline_clamp_f32;
419
434
  vk_pipeline pipeline_pad_f32;
435
+ vk_pipeline pipeline_roll_f32;
420
436
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
421
437
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
422
438
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
423
439
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
424
440
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
441
+ vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
425
442
  vk_pipeline pipeline_norm_f32;
426
443
  vk_pipeline pipeline_group_norm_f32;
427
444
  vk_pipeline pipeline_rms_norm_f32;
445
+ vk_pipeline pipeline_rms_norm_mul_f32;
428
446
  vk_pipeline pipeline_rms_norm_back_f32;
429
447
  vk_pipeline pipeline_l2_norm_f32;
430
448
 
431
449
  // [src/dst 0=fp32,1=fp16]
432
450
  vk_pipeline pipeline_gelu[2];
451
+ vk_pipeline pipeline_gelu_erf[2];
433
452
  vk_pipeline pipeline_gelu_quick[2];
434
453
  vk_pipeline pipeline_silu[2];
435
454
  vk_pipeline pipeline_relu[2];
436
455
  vk_pipeline pipeline_tanh[2];
437
456
  vk_pipeline pipeline_sigmoid[2];
438
457
 
458
+ vk_pipeline pipeline_geglu[2];
459
+ vk_pipeline pipeline_reglu[2];
460
+ vk_pipeline pipeline_swiglu[2];
461
+ vk_pipeline pipeline_geglu_erf[2];
462
+ vk_pipeline pipeline_geglu_quick[2];
463
+
439
464
  vk_pipeline pipeline_leaky_relu_f32;
440
465
  vk_pipeline pipeline_silu_back_f32;
441
466
  vk_pipeline pipeline_diag_mask_inf_f32;
@@ -461,26 +486,11 @@ struct vk_device_struct {
461
486
  vk_pipeline pipeline_conv2d_dw_cwhn_f32;
462
487
 
463
488
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
464
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
465
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
466
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
467
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
468
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
469
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
470
-
471
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
472
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
473
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
474
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
475
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
476
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
477
-
478
- vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
479
- vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
480
- vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
481
- vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
482
- vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
483
- vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
489
+ vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
490
+
491
+ vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
492
+
493
+ vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
484
494
 
485
495
  vk_pipeline pipeline_flash_attn_split_k_reduce;
486
496
 
@@ -493,6 +503,8 @@ struct vk_device_struct {
493
503
 
494
504
  ggml_backend_buffer_type buffer_type;
495
505
 
506
+ bool disable_fusion;
507
+
496
508
  #ifdef GGML_VULKAN_MEMORY_DEBUG
497
509
  std::unique_ptr<vk_memory_logger> memory_logger;
498
510
  #endif
@@ -627,6 +639,8 @@ struct vk_flash_attn_push_constants {
627
639
  uint32_t nev2;
628
640
  uint32_t nev3;
629
641
  uint32_t nem1;
642
+ uint32_t nem2;
643
+ uint32_t nem3;
630
644
 
631
645
  uint32_t nb01;
632
646
  uint32_t nb02;
@@ -637,14 +651,12 @@ struct vk_flash_attn_push_constants {
637
651
  uint32_t nb21;
638
652
  uint32_t nb22;
639
653
  uint32_t nb23;
640
- uint32_t nb31;
641
654
 
642
655
  float scale;
643
656
  float max_bias;
644
657
  float logit_softcap;
645
658
 
646
- uint32_t mask;
647
- uint32_t n_head_log2;
659
+ uint32_t mask_n_head_log2;
648
660
  float m0;
649
661
  float m1;
650
662
 
@@ -652,6 +664,7 @@ struct vk_flash_attn_push_constants {
652
664
  uint32_t split_kv;
653
665
  uint32_t k_num;
654
666
  };
667
+ static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
655
668
 
656
669
  struct vk_op_push_constants {
657
670
  uint32_t KX;
@@ -660,6 +673,13 @@ struct vk_op_push_constants {
660
673
  float param2;
661
674
  };
662
675
 
676
+ struct vk_op_glu_push_constants {
677
+ uint32_t N;
678
+ uint32_t ne00;
679
+ uint32_t ne20;
680
+ uint32_t mode; // 0: default, 1: swapped, 2: split
681
+ };
682
+
663
683
  struct vk_op_unary_push_constants {
664
684
  uint32_t ne;
665
685
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -675,6 +695,37 @@ struct vk_op_unary_push_constants {
675
695
  };
676
696
  static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
677
697
 
698
+ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
699
+ GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
700
+ ne = ne != 0 ? ne : ggml_nelements(dst);
701
+ GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
702
+
703
+ vk_op_unary_push_constants p{};
704
+ p.ne = (uint32_t)ne;
705
+
706
+ size_t src0_tsize = ggml_type_size(src0->type);
707
+ p.ne00 = (uint32_t)src0->ne[0];
708
+ p.ne01 = (uint32_t)src0->ne[1];
709
+ p.ne02 = (uint32_t)src0->ne[2];
710
+ p.ne03 = (uint32_t)src0->ne[3];
711
+ p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
712
+ p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
713
+ p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
714
+ p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
715
+
716
+ size_t dst_tsize = ggml_type_size(dst->type);
717
+ p.ne10 = (uint32_t)dst->ne[0];
718
+ p.ne11 = (uint32_t)dst->ne[1];
719
+ p.ne12 = (uint32_t)dst->ne[2];
720
+ p.ne13 = (uint32_t)dst->ne[3];
721
+ p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
722
+ p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
723
+ p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
724
+ p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
725
+
726
+ return p; // fastdiv values and offsets are initialized later in ggml_vk_op
727
+ }
728
+
678
729
  // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
679
730
  // Precompute mp (m' in the paper) and L such that division
680
731
  // can be computed using a multiply (high 32b of 64b result)
@@ -743,6 +794,14 @@ struct vk_op_rope_push_constants {
743
794
  struct vk_op_soft_max_push_constants {
744
795
  uint32_t KX;
745
796
  uint32_t KY;
797
+ uint32_t ne00;
798
+ uint32_t ne01;
799
+ uint32_t ne02;
800
+ uint32_t ne12;
801
+ uint32_t ne13;
802
+ uint32_t nb11;
803
+ uint32_t nb12;
804
+ uint32_t nb13;
746
805
  float scale;
747
806
  float max_bias;
748
807
  float m0;
@@ -836,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
836
895
 
837
896
  struct vk_op_upscale_push_constants {
838
897
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
898
+ uint32_t ne00; uint32_t ne01;
839
899
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
840
900
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
841
901
  float sf0; float sf1; float sf2; float sf3;
@@ -978,6 +1038,10 @@ struct ggml_backend_vk_context {
978
1038
 
979
1039
  vk_command_pool compute_cmd_pool;
980
1040
  vk_command_pool transfer_cmd_pool;
1041
+
1042
+ // number of additional consecutive nodes that are being fused with the
1043
+ // node currently being processed
1044
+ int num_additional_fused_ops {};
981
1045
  };
982
1046
 
983
1047
  static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1041,6 +1105,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
1041
1105
  struct vk_instance_t {
1042
1106
  vk::Instance instance;
1043
1107
 
1108
+ bool debug_utils_support = false; // VK_EXT_debug_utils enabled
1109
+ PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
1110
+ PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
1111
+ PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
1112
+ PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
1113
+ PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
1114
+ PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
1115
+
1044
1116
  std::vector<size_t> device_indices;
1045
1117
  vk_device devices[GGML_VK_MAX_DEVICES];
1046
1118
  };
@@ -1055,8 +1127,8 @@ static size_t vk_skip_checks;
1055
1127
  static size_t vk_output_tensor;
1056
1128
 
1057
1129
  static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
1058
- static void ggml_vk_check_results_0(ggml_tensor * tensor);
1059
- static void ggml_vk_check_results_1(ggml_tensor * tensor);
1130
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1131
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1060
1132
  #endif
1061
1133
 
1062
1134
  typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -1180,8 +1252,16 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1180
1252
  }
1181
1253
  pipeline->compiled = true;
1182
1254
 
1255
+ if (vk_instance.debug_utils_support) {
1256
+ vk::DebugUtilsObjectNameInfoEXT duoni;
1257
+ duoni.objectType = vk::ObjectType::ePipeline;
1258
+ duoni.pObjectName = pipeline->name.c_str();
1259
+ duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
1260
+ vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
1261
+ }
1262
+
1183
1263
  {
1184
- std::lock_guard<std::mutex> guard(device->mutex);
1264
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1185
1265
  device->pipelines.insert({ pipeline->name, pipeline });
1186
1266
  }
1187
1267
 
@@ -1395,7 +1475,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
1395
1475
 
1396
1476
  static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
1397
1477
  VK_LOG_DEBUG("ggml_vk_create_queue()");
1398
- std::lock_guard<std::mutex> guard(device->mutex);
1478
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1399
1479
 
1400
1480
  q.queue_family_index = queue_family_index;
1401
1481
  q.transfer_only = transfer_only;
@@ -1657,10 +1737,46 @@ enum FaCodePath {
1657
1737
  FA_COOPMAT2,
1658
1738
  };
1659
1739
 
1740
+ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
1741
+ if (hsk != 192 && hsk != 576 && hsk != hsv) {
1742
+ return FA_HEAD_SIZE_UNSUPPORTED;
1743
+ }
1744
+ switch (hsk) {
1745
+ case 64: return FA_HEAD_SIZE_64;
1746
+ case 80: return FA_HEAD_SIZE_80;
1747
+ case 96: return FA_HEAD_SIZE_96;
1748
+ case 112: return FA_HEAD_SIZE_112;
1749
+ case 128: return FA_HEAD_SIZE_128;
1750
+ case 192:
1751
+ if (hsv == 192) {
1752
+ return FA_HEAD_SIZE_192;
1753
+ } else if (hsv == 128) {
1754
+ return FA_HEAD_SIZE_192_128;
1755
+ } else {
1756
+ return FA_HEAD_SIZE_UNSUPPORTED;
1757
+ }
1758
+ case 256: return FA_HEAD_SIZE_256;
1759
+ case 576:
1760
+ if (hsv == 512) {
1761
+ return FA_HEAD_SIZE_576_512;
1762
+ } else {
1763
+ return FA_HEAD_SIZE_UNSUPPORTED;
1764
+ }
1765
+ default: return FA_HEAD_SIZE_UNSUPPORTED;
1766
+ }
1767
+ }
1768
+
1660
1769
  // number of rows/cols for flash attention shader
1661
1770
  static constexpr uint32_t flash_attention_num_small_rows = 32;
1662
1771
  static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1663
- static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1772
+
1773
+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
1774
+ if (hsv >= 512) {
1775
+ return 2;
1776
+ } else {
1777
+ return 8;
1778
+ }
1779
+ }
1664
1780
 
1665
1781
  // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
1666
1782
  // 128 threads split into four subgroups, each subgroup does 1/4
@@ -1677,14 +1793,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
1677
1793
  }
1678
1794
  }
1679
1795
 
1680
- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1796
+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
1681
1797
  GGML_UNUSED(clamp);
1798
+ GGML_UNUSED(hsv);
1682
1799
 
1683
1800
  if (path == FA_SCALAR) {
1684
1801
  if (small_rows) {
1685
1802
  return {scalar_flash_attention_num_small_rows, 64};
1686
1803
  } else {
1687
- return {scalar_flash_attention_num_large_rows, 32};
1804
+ return {get_fa_scalar_num_large_rows(hsv), 32};
1688
1805
  }
1689
1806
  }
1690
1807
 
@@ -1702,8 +1819,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
1702
1819
  }
1703
1820
 
1704
1821
  // small cols to reduce register count
1705
- if (ggml_is_quantized(type) || D == 256) {
1706
- return {64, 32};
1822
+ if (ggml_is_quantized(type) || hsk >= 256) {
1823
+ if (hsk >= 512) {
1824
+ return {32, 32};
1825
+ } else {
1826
+ return {64, 32};
1827
+ }
1707
1828
  }
1708
1829
  return {64, 64};
1709
1830
  }
@@ -1745,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1745
1866
  const uint32_t warps = warptile[0] / warptile[10];
1746
1867
 
1747
1868
  const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1748
- const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1869
+ const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
1749
1870
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1750
1871
 
1751
1872
  const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1870,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1870
1991
  s_mmq_wg_denoms_k = { 32, 32, 1 };
1871
1992
 
1872
1993
  // spec constants and tile sizes for quant matmul_id
1873
- l_warptile_mmqid = { 256, 128, 64, 16, 0 };
1994
+ l_warptile_mmqid = { 256, 128, 128, 16, 0 };
1874
1995
  m_warptile_mmqid = { 256, 128, 64, 16, 0 };
1875
1996
  s_warptile_mmqid = { 256, 128, 64, 16, 0 };
1876
- l_mmqid_wg_denoms = { 128, 64, 1 };
1997
+ l_mmqid_wg_denoms = { 128, 128, 1 };
1877
1998
  m_mmqid_wg_denoms = { 128, 64, 1 };
1878
1999
  s_mmqid_wg_denoms = { 128, 64, 1 };
1879
2000
 
@@ -1995,19 +2116,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
1995
2116
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
1996
2117
  };
1997
2118
 
1998
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1999
- return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
2119
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2120
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2000
2121
  };
2001
2122
 
2002
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2123
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2003
2124
  // For large number of rows, 128 invocations seems to work best.
2004
2125
  // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
2005
2126
  // can't use 256 for D==80.
2006
2127
  // For scalar, use 128 (arbitrary)
2128
+ // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
2129
+ const uint32_t D = (hsk|hsv);
2007
2130
  uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
2008
2131
  ? scalar_flash_attention_workgroup_size
2009
2132
  : ((small_rows && (D % 32) == 0) ? 256 : 128);
2010
- auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
2133
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
2011
2134
 
2012
2135
  // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
2013
2136
  // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2016,26 +2139,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
2016
2139
 
2017
2140
  // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
2018
2141
  GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
2019
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
2142
+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
2020
2143
  };
2021
2144
 
2022
- #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
2023
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2024
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2025
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2026
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2027
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2028
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2029
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2030
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2145
+ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
2146
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2147
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2148
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2149
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2150
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2151
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2152
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2153
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2031
2154
 
2032
2155
  #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
2033
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
2034
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
2035
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
2036
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
2037
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
2038
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
2156
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
2157
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
2158
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
2159
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
2160
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
2161
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
2162
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
2163
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
2164
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
2039
2165
 
2040
2166
  CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
2041
2167
  CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2625,7 +2751,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2625
2751
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2626
2752
 
2627
2753
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2628
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
2754
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
2629
2755
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2630
2756
 
2631
2757
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2639,7 +2765,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2639
2765
 
2640
2766
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2641
2767
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2642
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2768
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2769
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
2643
2770
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2644
2771
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2645
2772
 
@@ -2656,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
2656
2783
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2657
2784
 
2658
2785
  if (device->float_controls_rte_fp16) {
2659
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2660
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2661
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2662
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2663
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2664
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2786
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2787
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2788
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2789
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2790
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2791
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2792
+ } else {
2793
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2794
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2795
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2796
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2797
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2798
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2799
+ }
2800
+
2801
+ if (device->float_controls_rte_fp16) {
2802
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2803
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2804
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2805
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2806
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2807
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2808
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2809
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2810
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2665
2811
  } else {
2666
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2667
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2668
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2669
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2670
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2671
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2812
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2813
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2814
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2815
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2816
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2817
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2818
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2819
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2820
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2672
2821
  }
2673
2822
 
2674
2823
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -2708,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2708
2857
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2709
2858
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2710
2859
 
2711
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2860
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2861
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2862
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2712
2863
 
2713
2864
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2714
2865
 
@@ -2720,6 +2871,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2720
2871
 
2721
2872
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2722
2873
 
2874
+ ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2875
+
2723
2876
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2724
2877
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2725
2878
 
@@ -2728,6 +2881,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2728
2881
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2729
2882
 
2730
2883
  CREATE_UNARY(gelu)
2884
+ CREATE_UNARY(gelu_erf)
2731
2885
  CREATE_UNARY(gelu_quick)
2732
2886
  CREATE_UNARY(silu)
2733
2887
  CREATE_UNARY(relu)
@@ -2735,6 +2889,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
2735
2889
  CREATE_UNARY(sigmoid)
2736
2890
  #undef CREATE_UNARY
2737
2891
 
2892
+ #define CREATE_GLU(name) \
2893
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2894
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
2895
+
2896
+ CREATE_GLU(geglu)
2897
+ CREATE_GLU(reglu)
2898
+ CREATE_GLU(swiglu)
2899
+ CREATE_GLU(geglu_erf)
2900
+ CREATE_GLU(geglu_quick)
2901
+ #undef CREATE_GLU
2902
+
2738
2903
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2739
2904
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2740
2905
 
@@ -3415,6 +3580,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
3415
3580
 
3416
3581
  device->idx = idx;
3417
3582
 
3583
+ device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
3584
+
3418
3585
  return device;
3419
3586
  }
3420
3587
 
@@ -3561,6 +3728,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
3561
3728
  static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
3562
3729
  static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
3563
3730
 
3731
+ static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
3732
+
3564
3733
  static void ggml_vk_instance_init() {
3565
3734
  if (vk_instance_initialized) {
3566
3735
  return;
@@ -3581,7 +3750,7 @@ static void ggml_vk_instance_init() {
3581
3750
  #ifdef __APPLE__
3582
3751
  const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
3583
3752
  #endif
3584
-
3753
+ const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
3585
3754
  std::vector<const char*> layers;
3586
3755
 
3587
3756
  if (validation_ext) {
@@ -3596,6 +3765,9 @@ static void ggml_vk_instance_init() {
3596
3765
  extensions.push_back("VK_KHR_portability_enumeration");
3597
3766
  }
3598
3767
  #endif
3768
+ if (debug_utils_ext) {
3769
+ extensions.push_back("VK_EXT_debug_utils");
3770
+ }
3599
3771
  vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
3600
3772
  #ifdef __APPLE__
3601
3773
  if (portability_enumeration_ext) {
@@ -3619,6 +3791,17 @@ static void ggml_vk_instance_init() {
3619
3791
  vk_instance.instance = vk::createInstance(instance_create_info);
3620
3792
  vk_instance_initialized = true;
3621
3793
 
3794
+ if (debug_utils_ext) {
3795
+ vk_instance.debug_utils_support = true;
3796
+ vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
3797
+ vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
3798
+ vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
3799
+ vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
3800
+ vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
3801
+ vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
3802
+
3803
+ }
3804
+
3622
3805
  vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
3623
3806
 
3624
3807
  // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -4091,6 +4274,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
4091
4274
  return nullptr;
4092
4275
  }
4093
4276
 
4277
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4094
4278
  device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
4095
4279
 
4096
4280
  return buf->ptr;
@@ -4101,6 +4285,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4101
4285
  return;
4102
4286
  }
4103
4287
  VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
4288
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4289
+
4104
4290
  vk_buffer buf;
4105
4291
  size_t index;
4106
4292
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4123,6 +4309,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4123
4309
  }
4124
4310
 
4125
4311
  static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
4312
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4126
4313
  buf = nullptr;
4127
4314
  buf_offset = 0;
4128
4315
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4424,7 +4611,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
4424
4611
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
4425
4612
  }
4426
4613
  } else {
4427
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4614
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4428
4615
 
4429
4616
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4430
4617
  ggml_vk_ctx_begin(dst->device, subctx);
@@ -4515,7 +4702,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
4515
4702
 
4516
4703
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
4517
4704
  } else {
4518
- std::lock_guard<std::mutex> guard(src->device->mutex);
4705
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4519
4706
 
4520
4707
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
4521
4708
  ggml_vk_ctx_begin(src->device, subctx);
@@ -4545,7 +4732,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
4545
4732
 
4546
4733
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
4547
4734
  if (src->device == dst->device) {
4548
- std::lock_guard<std::mutex> guard(src->device->mutex);
4735
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4549
4736
  VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
4550
4737
  // Copy within the device
4551
4738
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
@@ -4580,7 +4767,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
4580
4767
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
4581
4768
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
4582
4769
 
4583
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4770
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4584
4771
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4585
4772
  ggml_vk_ctx_begin(dst->device, subctx);
4586
4773
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
@@ -4807,9 +4994,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4807
4994
  // type size must be exactly 2 or 4.
4808
4995
  GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
4809
4996
  if ((ggml_type_size(src->type) % 4) == 0) {
4810
- return ctx->device->pipeline_contig_cpy_f32_f32;
4997
+ if (contig) {
4998
+ return ctx->device->pipeline_contig_cpy_f32_f32;
4999
+ } else {
5000
+ return ctx->device->pipeline_cpy_f32_f32;
5001
+ }
4811
5002
  } else {
4812
- return ctx->device->pipeline_contig_cpy_f16_f16;
5003
+ if (contig) {
5004
+ return ctx->device->pipeline_contig_cpy_f16_f16;
5005
+ } else {
5006
+ return ctx->device->pipeline_cpy_f16_f16;
5007
+ }
4813
5008
  }
4814
5009
  }
4815
5010
 
@@ -4870,7 +5065,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4870
5065
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
4871
5066
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4872
5067
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4873
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5068
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
4874
5069
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
4875
5070
 
4876
5071
  const uint64_t ne00 = src0->ne[0];
@@ -5098,7 +5293,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
5098
5293
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
5099
5294
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5100
5295
  std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
5101
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5296
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5102
5297
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5103
5298
 
5104
5299
  const uint64_t ne00 = src0->ne[0];
@@ -5699,7 +5894,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5699
5894
  std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
5700
5895
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5701
5896
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
5702
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5897
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5703
5898
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5704
5899
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
5705
5900
 
@@ -5893,14 +6088,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
5893
6088
  if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5894
6089
  ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
5895
6090
  } else {
5896
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
6091
+ // Split based on number of ids, to fit in shared memory
6092
+ const uint32_t nei0 = (uint32_t)src2->ne[0];
6093
+ const uint32_t nei1 = (uint32_t)src2->ne[1];
6094
+
6095
+ GGML_ASSERT(nei0 <= 4096);
6096
+ const uint32_t split_size = std::min(nei1, 4096u / nei0);
6097
+
6098
+ ggml_tensor src1_copy = *src1;
6099
+ ggml_tensor src2_copy = *src2;
6100
+ ggml_tensor dst_copy = *dst;
6101
+
6102
+ for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
6103
+ const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
6104
+
6105
+ src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
6106
+ src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
6107
+ dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
6108
+
6109
+ src1_copy.ne[2] = n_tokens;
6110
+ src2_copy.ne[1] = n_tokens;
6111
+ dst_copy.ne[2] = n_tokens;
6112
+
6113
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
6114
+ }
5897
6115
  }
5898
6116
  }
5899
6117
 
5900
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
6118
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
5901
6119
  // Needs to be kept up to date on shader changes
6120
+ GGML_UNUSED(hsv);
5902
6121
  const uint32_t wg_size = scalar_flash_attention_workgroup_size;
5903
- const uint32_t Br = scalar_flash_attention_num_large_rows;
6122
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
6123
+ const uint32_t Bc = scalar_flash_attention_Bc;
6124
+
6125
+ const uint32_t tmpsh = wg_size * sizeof(float);
6126
+ const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
6127
+
6128
+ const uint32_t masksh = Bc * Br * sizeof(float);
6129
+
6130
+ const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
6131
+
6132
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
6133
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
6134
+
6135
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
6136
+
6137
+ return supported;
6138
+ }
6139
+
6140
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
6141
+ // Needs to be kept up to date on shader changes
6142
+ GGML_UNUSED(hsv);
6143
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6144
+ const uint32_t Br = coopmat1_flash_attention_num_large_rows;
5904
6145
  const uint32_t Bc = scalar_flash_attention_Bc;
5905
6146
 
5906
6147
  const uint32_t acctype = f32acc ? 4 : 2;
@@ -5909,12 +6150,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5909
6150
  const uint32_t tmpsh = wg_size * sizeof(float);
5910
6151
  const uint32_t tmpshv4 = wg_size * 4 * acctype;
5911
6152
 
5912
- const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
6153
+ const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
5913
6154
 
5914
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
6155
+ const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
5915
6156
  const uint32_t sfsh = Bc * sfshstride * acctype;
5916
6157
 
5917
- const uint32_t kshstride = D / 4 + 2;
6158
+ const uint32_t kshstride = hsk / 4 + 2;
5918
6159
  const uint32_t ksh = Bc * kshstride * f16vec4;
5919
6160
 
5920
6161
  const uint32_t slope = Br * sizeof(float);
@@ -5922,7 +6163,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5922
6163
  const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
5923
6164
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
5924
6165
 
5925
- VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
6166
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
5926
6167
 
5927
6168
  return supported;
5928
6169
  }
@@ -5944,13 +6185,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5944
6185
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
5945
6186
 
5946
6187
  const uint32_t nem1 = mask ? mask->ne[1] : 0;
5947
- const uint32_t nbm1 = mask ? mask->nb[1] : 0;
6188
+ const uint32_t nem2 = mask ? mask->ne[2] : 0;
6189
+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
5948
6190
 
5949
- const uint32_t D = neq0;
6191
+ const uint32_t HSK = nek0;
6192
+ const uint32_t HSV = nev0;
5950
6193
  uint32_t N = neq1;
5951
6194
  const uint32_t KV = nek1;
5952
6195
 
5953
- GGML_ASSERT(ne0 == D);
6196
+ GGML_ASSERT(ne0 == HSV);
5954
6197
  GGML_ASSERT(ne2 == N);
5955
6198
 
5956
6199
  // input tensor rows must be contiguous
@@ -5958,12 +6201,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5958
6201
  GGML_ASSERT(nbk0 == ggml_type_size(k->type));
5959
6202
  GGML_ASSERT(nbv0 == ggml_type_size(v->type));
5960
6203
 
5961
- GGML_ASSERT(neq0 == D);
5962
- GGML_ASSERT(nek0 == D);
5963
- GGML_ASSERT(nev0 == D);
6204
+ GGML_ASSERT(neq0 == HSK);
5964
6205
 
5965
6206
  GGML_ASSERT(neq1 == N);
5966
- GGML_ASSERT(nev0 == D);
5967
6207
 
5968
6208
  GGML_ASSERT(nev1 == nek1);
5969
6209
 
@@ -5984,7 +6224,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5984
6224
  const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
5985
6225
  (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
5986
6226
 
5987
- const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
6227
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
5988
6228
 
5989
6229
  if (!coopmat_shape_supported || !coopmat_shmem_supported) {
5990
6230
  path = FA_SCALAR;
@@ -6004,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6004
6244
  case FA_SCALAR:
6005
6245
  case FA_COOPMAT1:
6006
6246
  // We may switch from coopmat1 to scalar, so use the scalar limit for both
6007
- max_gqa = scalar_flash_attention_num_large_rows;
6247
+ max_gqa = get_fa_scalar_num_large_rows(HSV);
6008
6248
  break;
6009
6249
  case FA_COOPMAT2:
6010
6250
  max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -6014,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6014
6254
  }
6015
6255
 
6016
6256
  if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6017
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6257
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
6018
6258
  // grouped query attention - make the N dimension equal to gqa_ratio, reduce
6019
6259
  // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
6020
6260
  // and change addressing calculations to index Q's dimension 2.
@@ -6037,47 +6277,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6037
6277
  path = FA_SCALAR;
6038
6278
  }
6039
6279
 
6280
+ // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
6281
+ if (path == FA_SCALAR &&
6282
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
6283
+ small_rows = true;
6284
+ }
6285
+
6040
6286
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
6041
6287
 
6288
+ FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
6289
+
6042
6290
  switch (path) {
6043
6291
  case FA_SCALAR:
6044
- switch (D) {
6045
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
6046
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
6047
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
6048
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
6049
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
6050
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
6051
- default:
6052
- GGML_ASSERT(!"unsupported D value");
6053
- return;
6054
- }
6292
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
6055
6293
  break;
6056
6294
  case FA_COOPMAT1:
6057
- switch (D) {
6058
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
6059
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
6060
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
6061
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
6062
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
6063
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
6064
- default:
6065
- GGML_ASSERT(!"unsupported D value");
6066
- return;
6067
- }
6295
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
6068
6296
  break;
6069
6297
  case FA_COOPMAT2:
6070
- switch (D) {
6071
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
6072
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
6073
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
6074
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
6075
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
6076
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
6077
- default:
6078
- GGML_ASSERT(!"unsupported D value");
6079
- return;
6080
- }
6298
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
6081
6299
  break;
6082
6300
  default:
6083
6301
  GGML_ASSERT(0);
@@ -6105,21 +6323,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6105
6323
  const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
6106
6324
 
6107
6325
  // Try to use split_k when KV is large enough to be worth the overhead
6108
- if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
6326
+ if (workgroups_x == 1 && shader_core_count > 0) {
6109
6327
  // Try to run two workgroups per SM.
6110
- split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6328
+ split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
6111
6329
  if (split_k > 1) {
6112
6330
  // Try to evenly split KV into split_k chunks, but it needs to be a multiple
6113
6331
  // of "align", so recompute split_k based on that.
6114
- split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
6332
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
6115
6333
  split_k = CEIL_DIV(KV, split_kv);
6116
6334
  workgroups_x = split_k;
6117
6335
  }
6118
6336
  }
6119
6337
 
6120
- // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6121
- // and the per-row m and L values (ne1 rows).
6122
- const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6338
+ // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
6339
+ // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6340
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
6123
6341
  if (split_k_size > ctx->device->max_memory_allocation_size) {
6124
6342
  GGML_ABORT("Requested preallocation size is too large");
6125
6343
  }
@@ -6206,18 +6424,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6206
6424
  }
6207
6425
  }
6208
6426
 
6427
+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6428
+
6209
6429
  const vk_flash_attn_push_constants pc = { N, KV,
6210
6430
  (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
6211
6431
  (uint32_t)neq2, (uint32_t)neq3,
6212
6432
  (uint32_t)nek2, (uint32_t)nek3,
6213
6433
  (uint32_t)nev2, (uint32_t)nev3,
6214
- nem1,
6434
+ nem1, nem2, nem3,
6215
6435
  q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
6216
6436
  k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
6217
6437
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6218
- nbm1,
6219
6438
  scale, max_bias, logit_softcap,
6220
- mask != nullptr, n_head_log2, m0, m1,
6439
+ mask_n_head_log2, m0, m1,
6221
6440
  gqa_ratio, split_kv, split_k };
6222
6441
 
6223
6442
  ggml_vk_sync_buffers(subctx);
@@ -6238,13 +6457,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6238
6457
  pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
6239
6458
 
6240
6459
  ggml_vk_sync_buffers(subctx);
6241
- const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6460
+ const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
6242
6461
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
6243
6462
  {
6244
6463
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
6245
6464
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
6246
6465
  },
6247
- pc2, { (uint32_t)ne1, 1, 1 });
6466
+ pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
6248
6467
  } else {
6249
6468
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
6250
6469
  {
@@ -6320,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6320
6539
  }
6321
6540
  return nullptr;
6322
6541
  case GGML_OP_UPSCALE:
6323
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6324
- return ctx->device->pipeline_upscale_f32;
6542
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6543
+ int mode = ggml_get_op_params_i32(dst, 0);
6544
+ switch (mode) {
6545
+ case GGML_SCALE_MODE_NEAREST:
6546
+ return ctx->device->pipeline_upscale_nearest_f32;
6547
+ case GGML_SCALE_MODE_BILINEAR:
6548
+ return ctx->device->pipeline_upscale_bilinear_f32;
6549
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6550
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6551
+ }
6325
6552
  }
6326
6553
  return nullptr;
6327
6554
  case GGML_OP_SCALE:
@@ -6354,6 +6581,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6354
6581
  return ctx->device->pipeline_pad_f32;
6355
6582
  }
6356
6583
  return nullptr;
6584
+ case GGML_OP_ROLL:
6585
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6586
+ return ctx->device->pipeline_roll_f32;
6587
+ }
6588
+ return nullptr;
6357
6589
  case GGML_OP_REPEAT:
6358
6590
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
6359
6591
  return ctx->device->pipeline_repeat_f32;
@@ -6368,6 +6600,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6368
6600
  case GGML_OP_CONT:
6369
6601
  case GGML_OP_DUP:
6370
6602
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
6603
+ case GGML_OP_SET_ROWS:
6604
+ return ctx->device->pipeline_set_rows[dst->type];
6371
6605
  case GGML_OP_SILU_BACK:
6372
6606
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6373
6607
  return ctx->device->pipeline_silu_back_f32;
@@ -6385,7 +6619,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6385
6619
  return nullptr;
6386
6620
  case GGML_OP_RMS_NORM:
6387
6621
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6388
- return ctx->device->pipeline_rms_norm_f32;
6622
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
6389
6623
  }
6390
6624
  return nullptr;
6391
6625
  case GGML_OP_RMS_NORM_BACK:
@@ -6410,6 +6644,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6410
6644
  return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
6411
6645
  case GGML_UNARY_OP_GELU:
6412
6646
  return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6647
+ case GGML_UNARY_OP_GELU_ERF:
6648
+ return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
6413
6649
  case GGML_UNARY_OP_GELU_QUICK:
6414
6650
  return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
6415
6651
  case GGML_UNARY_OP_RELU:
@@ -6422,6 +6658,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6422
6658
  break;
6423
6659
  }
6424
6660
  return nullptr;
6661
+ case GGML_OP_GLU:
6662
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6663
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6664
+ (src0->type != dst->type)) {
6665
+ return nullptr;
6666
+ }
6667
+
6668
+ switch (ggml_get_glu_op(dst)) {
6669
+ case GGML_GLU_OP_GEGLU:
6670
+ return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
6671
+ case GGML_GLU_OP_REGLU:
6672
+ return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6673
+ case GGML_GLU_OP_SWIGLU:
6674
+ return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6675
+ case GGML_GLU_OP_GEGLU_ERF:
6676
+ return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
6677
+ case GGML_GLU_OP_GEGLU_QUICK:
6678
+ return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
6679
+ default:
6680
+ break;
6681
+ }
6682
+ return nullptr;
6425
6683
  case GGML_OP_DIAG_MASK_INF:
6426
6684
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6427
6685
  return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6582,6 +6840,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
6582
6840
  case GGML_OP_RMS_NORM:
6583
6841
  case GGML_OP_CONV_2D_DW:
6584
6842
  case GGML_OP_IM2COL:
6843
+ case GGML_OP_SET_ROWS:
6585
6844
  return true;
6586
6845
  default:
6587
6846
  return false;
@@ -6876,12 +7135,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6876
7135
  case GGML_OP_COS:
6877
7136
  case GGML_OP_CLAMP:
6878
7137
  case GGML_OP_PAD:
7138
+ case GGML_OP_ROLL:
6879
7139
  case GGML_OP_REPEAT:
6880
7140
  case GGML_OP_REPEAT_BACK:
6881
7141
  case GGML_OP_CPY:
6882
7142
  case GGML_OP_CONCAT:
6883
7143
  case GGML_OP_UPSCALE:
6884
7144
  case GGML_OP_UNARY:
7145
+ case GGML_OP_GLU:
6885
7146
  case GGML_OP_CONV_2D_DW:
6886
7147
  {
6887
7148
  uint32_t ne = ggml_nelements(dst);
@@ -6894,6 +7155,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6894
7155
  ne *= ggml_type_size(src0->type) / 2;
6895
7156
  }
6896
7157
  }
7158
+ // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
7159
+ // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
7160
+ // So divide by block size here before splitting into 512x512 groups.
7161
+ if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7162
+ ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7163
+ }
6897
7164
  if (ne > 262144) {
6898
7165
  elements = { 512, 512, CEIL_DIV(ne, 262144) };
6899
7166
  } else if (ne > 512) {
@@ -6902,6 +7169,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6902
7169
  elements = { ne, 1, 1 };
6903
7170
  }
6904
7171
  } break;
7172
+ case GGML_OP_SET_ROWS:
7173
+ {
7174
+ uint32_t ne = ggml_nelements(src0);
7175
+ if (ggml_is_quantized(dst->type)) {
7176
+ // quants run 32 threads each doing QUANT_K elements
7177
+ ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
7178
+ } else {
7179
+ // scalar types do one element per thread, running 512 threads
7180
+ ne = CEIL_DIV(ne, 512);
7181
+ }
7182
+ if (ne > 262144) {
7183
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
7184
+ } else if (ne > 512) {
7185
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
7186
+ } else {
7187
+ elements = { ne, 1, 1 };
7188
+ }
7189
+ }
7190
+ break;
6905
7191
  default:
6906
7192
  elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
6907
7193
  break;
@@ -6922,7 +7208,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6922
7208
  }
6923
7209
  }
6924
7210
 
6925
- if (op == GGML_OP_SOFT_MAX) {
7211
+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
6926
7212
  // Empty src1 is possible in soft_max, but the shader needs a buffer
6927
7213
  vk_subbuffer subbuf_y;
6928
7214
  if (use_src1) {
@@ -7311,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7311
7597
 
7312
7598
  static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7313
7599
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7600
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7314
7601
 
7315
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7316
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7317
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7318
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7602
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7603
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7604
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7605
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7606
+
7607
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7608
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7609
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7610
+ }
7319
7611
 
7320
7612
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7321
7613
  (uint32_t)ggml_nelements(dst), 0, 0,
7614
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7322
7615
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7323
7616
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7324
7617
  sf0, sf1, sf2, sf3,
@@ -7326,123 +7619,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
7326
7619
  }
7327
7620
 
7328
7621
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7329
- float * op_params = (float *)dst->op_params;
7330
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7331
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7622
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7623
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7624
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7332
7625
 
7333
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7334
- (uint32_t)ggml_nelements(src0),
7335
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7336
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7337
- 0,
7338
- op_params[0], 0.0f,
7339
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7340
- }, dryrun);
7626
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
7341
7627
  }
7342
7628
 
7343
7629
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7344
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7345
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7346
-
7347
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7348
- (uint32_t)ggml_nelements(src0),
7349
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7350
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7351
- 0,
7352
- 0.0f, 0.0f,
7353
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7354
- }, dryrun);
7630
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
7355
7631
  }
7356
7632
 
7357
7633
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7358
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7359
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7360
-
7361
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7362
- (uint32_t)ggml_nelements(src0),
7363
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7364
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7365
- 0,
7366
- 0.0f, 0.0f,
7367
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7368
- }, dryrun);
7634
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
7369
7635
  }
7370
7636
 
7371
7637
  static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7372
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7373
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7374
-
7375
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7376
- (uint32_t)ggml_nelements(src0),
7377
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7378
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7379
- 0,
7380
- 0.0f, 0.0f,
7381
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7382
- }, dryrun);
7638
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
7383
7639
  }
7384
7640
 
7385
7641
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7386
- float * op_params = (float *)dst->op_params;
7387
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7388
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7642
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7643
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7644
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7389
7645
 
7390
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7391
- (uint32_t)ggml_nelements(src0),
7392
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7393
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7394
- 0,
7395
- op_params[0], op_params[1],
7396
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7397
- }, dryrun);
7646
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
7398
7647
  }
7399
7648
 
7400
7649
  static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7401
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7402
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7650
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7651
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7652
+ }
7403
7653
 
7404
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7405
- (uint32_t)ggml_nelements(dst),
7406
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7407
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7408
- 0,
7409
- 0.0f, 0.0f,
7410
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7411
- }, dryrun);
7654
+ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7655
+ const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7656
+ const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7657
+ const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7658
+ const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7659
+ const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7660
+ const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7661
+
7662
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7663
+ memcpy(&p.param1, &s01_packed, sizeof(float));
7664
+ memcpy(&p.param2, &s23_packed, sizeof(float));
7665
+
7666
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
7412
7667
  }
7413
7668
 
7414
7669
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7415
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7416
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7417
-
7418
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7419
- (uint32_t)ggml_nelements(dst),
7420
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7421
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7422
- 0,
7423
- 0.0f, 0.0f,
7424
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7425
- }, dryrun);
7670
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7671
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
7426
7672
  }
7427
7673
 
7428
7674
  static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7429
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7430
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7431
-
7432
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7433
- (uint32_t)ggml_nelements(dst),
7434
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7435
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7436
- 0,
7437
- 0.0f, 0.0f,
7438
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7439
- }, dryrun);
7675
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7676
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
7440
7677
  }
7441
7678
 
7442
7679
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7443
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7444
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7445
-
7446
7680
  uint32_t ne = (uint32_t)ggml_nelements(src0);
7447
7681
  if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7448
7682
  // Convert from number of logical elements to 2- or 4-byte units.
@@ -7454,13 +7688,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
7454
7688
  }
7455
7689
  }
7456
7690
 
7457
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7458
- ne,
7459
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7460
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7691
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7692
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
7693
+ }
7694
+
7695
+ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7696
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7697
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7698
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7699
+
7700
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
7701
+ (uint32_t)ggml_nelements(src0),
7702
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7703
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7704
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7461
7705
  0,
7462
- 0.0f, 0.0f,
7463
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7706
+ 0.0f, 0.0f, 0,
7464
7707
  }, dryrun);
7465
7708
  }
7466
7709
 
@@ -7485,18 +7728,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
7485
7728
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
7486
7729
  }
7487
7730
 
7488
- static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7489
- float * op_params = (float *)dst->op_params;
7731
+ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
7490
7732
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7733
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7491
7734
  const uint32_t dst_type_size = ggml_type_size(dst->type);
7492
7735
 
7493
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7736
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
7494
7737
  (uint32_t)ggml_nelements(src0),
7495
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7496
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7738
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7739
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7740
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7497
7741
  0,
7498
- op_params[0], 0.0f,
7499
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7742
+ op_params[0], 0.0f, 0,
7500
7743
  }, dryrun);
7501
7744
  }
7502
7745
 
@@ -7514,6 +7757,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
7514
7757
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7515
7758
  }
7516
7759
 
7760
+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7761
+ const bool swapped = (bool)dst->op_params[1];
7762
+ const bool split = src1 != nullptr;
7763
+
7764
+ GGML_ASSERT(ggml_is_contiguous(src0));
7765
+
7766
+ if (!split) {
7767
+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7768
+ } else {
7769
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7770
+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7771
+ GGML_ASSERT(src0->type == src1->type);
7772
+ }
7773
+
7774
+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
7775
+
7776
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
7777
+ }
7778
+
7517
7779
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7518
7780
  int32_t * op_params = (int32_t *)dst->op_params;
7519
7781
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -7529,7 +7791,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7529
7791
  const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
7530
7792
  const uint32_t nrows_y = (uint32_t)src0->ne[1];
7531
7793
 
7532
- const uint32_t n_head_kv = nrows_x/nrows_y;
7794
+ const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7795
+ const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7796
+ const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7797
+ const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7798
+ const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7799
+
7800
+ const uint32_t n_head_kv = src0->ne[2];
7533
7801
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
7534
7802
 
7535
7803
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7538,6 +7806,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7538
7806
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
7539
7807
  ncols,
7540
7808
  src1 != nullptr ? nrows_y : (uint32_t)0,
7809
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7810
+ ne12, ne13,
7811
+ nb11, nb12, nb13,
7541
7812
  scale, max_bias,
7542
7813
  m0, m1,
7543
7814
  n_head_log2,
@@ -8687,11 +8958,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
8687
8958
  }
8688
8959
  }
8689
8960
 
8690
- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
8961
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
8691
8962
 
8692
8963
  // Returns true if node has enqueued work into the queue, false otherwise
8693
8964
  // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
8694
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8965
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8966
+ ggml_tensor * node = cgraph->nodes[node_idx];
8695
8967
  if (ggml_is_empty(node) || !node->buffer) {
8696
8968
  return false;
8697
8969
  }
@@ -8716,6 +8988,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8716
8988
  switch (ggml_get_unary_op(node)) {
8717
8989
  case GGML_UNARY_OP_SILU:
8718
8990
  case GGML_UNARY_OP_GELU:
8991
+ case GGML_UNARY_OP_GELU_ERF:
8719
8992
  case GGML_UNARY_OP_GELU_QUICK:
8720
8993
  case GGML_UNARY_OP_RELU:
8721
8994
  case GGML_UNARY_OP_TANH:
@@ -8725,6 +8998,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8725
8998
  return false;
8726
8999
  }
8727
9000
  break;
9001
+ case GGML_OP_GLU:
9002
+ switch (ggml_get_glu_op(node)) {
9003
+ case GGML_GLU_OP_GEGLU:
9004
+ case GGML_GLU_OP_REGLU:
9005
+ case GGML_GLU_OP_SWIGLU:
9006
+ case GGML_GLU_OP_GEGLU_ERF:
9007
+ case GGML_GLU_OP_GEGLU_QUICK:
9008
+ break;
9009
+ default:
9010
+ return false;
9011
+ }
9012
+ break;
8728
9013
  case GGML_OP_REPEAT:
8729
9014
  case GGML_OP_REPEAT_BACK:
8730
9015
  case GGML_OP_GET_ROWS:
@@ -8741,7 +9026,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8741
9026
  case GGML_OP_COS:
8742
9027
  case GGML_OP_CLAMP:
8743
9028
  case GGML_OP_PAD:
9029
+ case GGML_OP_ROLL:
8744
9030
  case GGML_OP_CPY:
9031
+ case GGML_OP_SET_ROWS:
8745
9032
  case GGML_OP_CONT:
8746
9033
  case GGML_OP_DUP:
8747
9034
  case GGML_OP_SILU_BACK:
@@ -8808,6 +9095,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8808
9095
  case GGML_OP_CLAMP:
8809
9096
  case GGML_OP_PAD:
8810
9097
  case GGML_OP_CPY:
9098
+ case GGML_OP_SET_ROWS:
8811
9099
  case GGML_OP_CONT:
8812
9100
  case GGML_OP_DUP:
8813
9101
  case GGML_OP_SILU_BACK:
@@ -8817,6 +9105,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8817
9105
  case GGML_OP_RMS_NORM_BACK:
8818
9106
  case GGML_OP_L2_NORM:
8819
9107
  case GGML_OP_UNARY:
9108
+ case GGML_OP_GLU:
8820
9109
  case GGML_OP_DIAG_MASK_INF:
8821
9110
  case GGML_OP_SOFT_MAX:
8822
9111
  case GGML_OP_SOFT_MAX_BACK:
@@ -8909,12 +9198,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8909
9198
  case GGML_OP_PAD:
8910
9199
  ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
8911
9200
 
9201
+ break;
9202
+ case GGML_OP_ROLL:
9203
+ ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9204
+
8912
9205
  break;
8913
9206
  case GGML_OP_CPY:
8914
9207
  case GGML_OP_CONT:
8915
9208
  case GGML_OP_DUP:
8916
9209
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
8917
9210
 
9211
+ break;
9212
+ case GGML_OP_SET_ROWS:
9213
+ ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
9214
+
8918
9215
  break;
8919
9216
  case GGML_OP_SILU_BACK:
8920
9217
  ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8929,8 +9226,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8929
9226
 
8930
9227
  break;
8931
9228
  case GGML_OP_RMS_NORM:
8932
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8933
-
9229
+ if (ctx->num_additional_fused_ops > 0) {
9230
+ // fused rms_norm + mul
9231
+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9232
+ ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
9233
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
9234
+ } else {
9235
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
9236
+ }
8934
9237
  break;
8935
9238
  case GGML_OP_RMS_NORM_BACK:
8936
9239
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8944,6 +9247,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8944
9247
  switch (ggml_get_unary_op(node)) {
8945
9248
  case GGML_UNARY_OP_SILU:
8946
9249
  case GGML_UNARY_OP_GELU:
9250
+ case GGML_UNARY_OP_GELU_ERF:
8947
9251
  case GGML_UNARY_OP_GELU_QUICK:
8948
9252
  case GGML_UNARY_OP_RELU:
8949
9253
  case GGML_UNARY_OP_TANH:
@@ -8954,6 +9258,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8954
9258
  return false;
8955
9259
  }
8956
9260
  break;
9261
+ case GGML_OP_GLU:
9262
+ switch (ggml_get_glu_op(node)) {
9263
+ case GGML_GLU_OP_GEGLU:
9264
+ case GGML_GLU_OP_REGLU:
9265
+ case GGML_GLU_OP_SWIGLU:
9266
+ case GGML_GLU_OP_GEGLU_ERF:
9267
+ case GGML_GLU_OP_GEGLU_QUICK:
9268
+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
9269
+ break;
9270
+ default:
9271
+ return false;
9272
+ }
9273
+ break;
8957
9274
  case GGML_OP_DIAG_MASK_INF:
8958
9275
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
8959
9276
 
@@ -9075,12 +9392,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9075
9392
 
9076
9393
  ctx->compute_ctx.reset();
9077
9394
 
9078
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
9395
+ bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
9079
9396
  if (!ok) {
9080
9397
  if (node->op == GGML_OP_UNARY) {
9081
9398
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
9082
- }
9083
- else {
9399
+ } else if (node->op == GGML_OP_GLU) {
9400
+ std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
9401
+ } else {
9084
9402
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
9085
9403
  }
9086
9404
  }
@@ -9089,7 +9407,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9089
9407
  return true;
9090
9408
  }
9091
9409
 
9092
- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9410
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9411
+ GGML_UNUSED(cgraph);
9093
9412
  ggml_backend_buffer * buf = nullptr;
9094
9413
 
9095
9414
  switch (tensor->op) {
@@ -9107,7 +9426,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9107
9426
  case GGML_OP_COS:
9108
9427
  case GGML_OP_CLAMP:
9109
9428
  case GGML_OP_PAD:
9429
+ case GGML_OP_ROLL:
9110
9430
  case GGML_OP_CPY:
9431
+ case GGML_OP_SET_ROWS:
9111
9432
  case GGML_OP_CONT:
9112
9433
  case GGML_OP_DUP:
9113
9434
  case GGML_OP_SILU_BACK:
@@ -9149,6 +9470,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9149
9470
  switch (ggml_get_unary_op(tensor)) {
9150
9471
  case GGML_UNARY_OP_SILU:
9151
9472
  case GGML_UNARY_OP_GELU:
9473
+ case GGML_UNARY_OP_GELU_ERF:
9152
9474
  case GGML_UNARY_OP_GELU_QUICK:
9153
9475
  case GGML_UNARY_OP_RELU:
9154
9476
  case GGML_UNARY_OP_TANH:
@@ -9159,6 +9481,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9159
9481
  return false;
9160
9482
  }
9161
9483
  break;
9484
+ case GGML_OP_GLU:
9485
+ switch (ggml_get_glu_op(tensor)) {
9486
+ case GGML_GLU_OP_GEGLU:
9487
+ case GGML_GLU_OP_REGLU:
9488
+ case GGML_GLU_OP_SWIGLU:
9489
+ case GGML_GLU_OP_GEGLU_ERF:
9490
+ case GGML_GLU_OP_GEGLU_QUICK:
9491
+ buf = tensor->buffer;
9492
+ break;
9493
+ default:
9494
+ return false;
9495
+ }
9496
+ break;
9162
9497
  case GGML_OP_MUL_MAT:
9163
9498
  case GGML_OP_MUL_MAT_ID:
9164
9499
  case GGML_OP_FLASH_ATTN_EXT:
@@ -9185,7 +9520,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9185
9520
  // Only run if ctx hasn't been submitted yet
9186
9521
  if (!subctx->seqs.empty()) {
9187
9522
  #ifdef GGML_VULKAN_CHECK_RESULTS
9188
- ggml_vk_check_results_0(tensor);
9523
+ ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
9189
9524
  use_fence = true;
9190
9525
  #endif
9191
9526
 
@@ -9205,7 +9540,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9205
9540
  ggml_vk_wait_for_fence(ctx);
9206
9541
  }
9207
9542
  #ifdef GGML_VULKAN_CHECK_RESULTS
9208
- ggml_vk_check_results_1(tensor);
9543
+ ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
9209
9544
  #endif
9210
9545
  }
9211
9546
 
@@ -9652,16 +9987,59 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
9652
9987
  return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
9653
9988
  }
9654
9989
 
9990
+ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
9991
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
9992
+ return false;
9993
+ }
9994
+
9995
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
9996
+ // additional constraints specific to this fusion
9997
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
9998
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9999
+
10000
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
10001
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
10002
+ // rms_norm only supports f32
10003
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
10004
+ mul->src[1]->type != GGML_TYPE_F32 ||
10005
+ mul->type != GGML_TYPE_F32) {
10006
+ return false;
10007
+ }
10008
+ // if rms_norm is the B operand, then we don't handle broadcast
10009
+ if (rms_norm == mul->src[1] &&
10010
+ mul->src[0]->ne[1] != rms_norm->ne[1]) {
10011
+ return false;
10012
+ }
10013
+ // rms_norm shader assumes contiguous rows
10014
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
10015
+ return false;
10016
+ }
10017
+ }
10018
+ return true;
10019
+ }
10020
+
9655
10021
  static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9656
10022
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
9657
10023
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
9658
10024
 
10025
+ if (vk_instance.debug_utils_support) {
10026
+ vk::DebugUtilsLabelEXT dul = {};
10027
+ dul.pLabelName = "ggml_backend_vk_graph_compute";
10028
+ dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
10029
+ vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
10030
+ }
10031
+
9659
10032
  uint64_t total_mat_mul_bytes = 0;
9660
10033
  for (int i = 0; i < cgraph->n_nodes; i++) {
9661
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
10034
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10035
+ ctx->num_additional_fused_ops = 1;
10036
+ }
10037
+ ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
9662
10038
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
9663
10039
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9664
10040
  }
10041
+ i += ctx->num_additional_fused_ops;
10042
+ ctx->num_additional_fused_ops = 0;
9665
10043
  }
9666
10044
  if (ctx->device->need_compiles) {
9667
10045
  ggml_vk_load_shaders(ctx->device);
@@ -9723,14 +10101,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9723
10101
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9724
10102
  }
9725
10103
 
10104
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10105
+ ctx->num_additional_fused_ops = 1;
10106
+ }
10107
+
9726
10108
  // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
9727
10109
  bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
9728
10110
  bool submit = (submitted_nodes >= nodes_per_submit) ||
9729
10111
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9730
- (i == last_node) ||
10112
+ (i + ctx->num_additional_fused_ops == last_node) ||
9731
10113
  (almost_ready && !ctx->almost_ready_fence_pending);
9732
10114
 
9733
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
10115
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
9734
10116
 
9735
10117
  if (vk_perf_logger_enabled) {
9736
10118
  if (ctx->compute_ctx.expired()) {
@@ -9740,7 +10122,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9740
10122
  } else {
9741
10123
  compute_ctx = ctx->compute_ctx.lock();
9742
10124
  }
9743
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
10125
+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
10126
+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
10127
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
10128
+ }
9744
10129
  }
9745
10130
 
9746
10131
  if (enqueued) {
@@ -9762,6 +10147,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9762
10147
  }
9763
10148
  submit_count++;
9764
10149
  }
10150
+ i += ctx->num_additional_fused_ops;
10151
+ ctx->num_additional_fused_ops = 0;
9765
10152
  }
9766
10153
 
9767
10154
  if (vk_perf_logger_enabled) {
@@ -9923,6 +10310,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9923
10310
  case GGML_OP_UNARY:
9924
10311
  switch (ggml_get_unary_op(op)) {
9925
10312
  case GGML_UNARY_OP_GELU:
10313
+ case GGML_UNARY_OP_GELU_ERF:
9926
10314
  case GGML_UNARY_OP_GELU_QUICK:
9927
10315
  case GGML_UNARY_OP_SILU:
9928
10316
  case GGML_UNARY_OP_RELU:
@@ -9936,15 +10324,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9936
10324
  return false;
9937
10325
  }
9938
10326
  break;
10327
+ case GGML_OP_GLU:
10328
+ switch (ggml_get_glu_op(op)) {
10329
+ case GGML_GLU_OP_GEGLU:
10330
+ case GGML_GLU_OP_REGLU:
10331
+ case GGML_GLU_OP_SWIGLU:
10332
+ case GGML_GLU_OP_GEGLU_ERF:
10333
+ case GGML_GLU_OP_GEGLU_QUICK:
10334
+ return ggml_is_contiguous(op->src[0]) &&
10335
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10336
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
10337
+ (op->src[0]->type == op->type);
10338
+ default:
10339
+ return false;
10340
+ }
10341
+ break;
9939
10342
  case GGML_OP_MUL_MAT:
9940
10343
  case GGML_OP_MUL_MAT_ID:
9941
10344
  {
9942
10345
  ggml_type src0_type = op->src[0]->type;
9943
10346
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
9944
10347
  const vk_device& device = ggml_vk_get_device(ctx->device);
9945
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
9946
- // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
9947
- return false;
10348
+ if (op->op == GGML_OP_MUL_MAT_ID) {
10349
+ if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10350
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10351
+ return false;
10352
+ }
10353
+ // Check against size of shared memory variable
10354
+ if (op->src[2]->ne[0] > 4096) {
10355
+ return false;
10356
+ }
9948
10357
  }
9949
10358
  switch (src0_type) {
9950
10359
  case GGML_TYPE_F32:
@@ -10002,19 +10411,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10002
10411
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10003
10412
  auto device = ggml_vk_get_device(ctx->device);
10004
10413
  bool coopmat2 = device->coopmat2;
10005
- switch (op->src[0]->ne[0]) {
10006
- case 64:
10007
- case 80:
10008
- case 96:
10009
- case 112:
10010
- case 128:
10011
- case 256:
10012
- break;
10013
- default:
10014
- return false;
10015
- }
10016
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
10017
- // different head sizes of K and V are not supported yet
10414
+ FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
10415
+ if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
10018
10416
  return false;
10019
10417
  }
10020
10418
  if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10094,6 +10492,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10094
10492
  return false;
10095
10493
  }
10096
10494
  } break;
10495
+ case GGML_OP_SET_ROWS:
10496
+ {
10497
+ switch (op->type) {
10498
+ case GGML_TYPE_F32:
10499
+ case GGML_TYPE_F16:
10500
+ case GGML_TYPE_BF16:
10501
+ case GGML_TYPE_Q4_0:
10502
+ case GGML_TYPE_Q4_1:
10503
+ case GGML_TYPE_Q5_0:
10504
+ case GGML_TYPE_Q5_1:
10505
+ case GGML_TYPE_Q8_0:
10506
+ case GGML_TYPE_IQ4_NL:
10507
+ return true;
10508
+ default:
10509
+ return false;
10510
+ }
10511
+ } break;
10097
10512
  case GGML_OP_CONT:
10098
10513
  case GGML_OP_CPY:
10099
10514
  case GGML_OP_DUP:
@@ -10178,11 +10593,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10178
10593
  case GGML_OP_CLAMP:
10179
10594
  return op->src[0]->type == GGML_TYPE_F32;
10180
10595
  case GGML_OP_UPSCALE:
10181
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10182
10596
  case GGML_OP_ACC:
10183
10597
  case GGML_OP_CONCAT:
10184
10598
  case GGML_OP_SCALE:
10185
10599
  case GGML_OP_PAD:
10600
+ case GGML_OP_ROLL:
10186
10601
  case GGML_OP_DIAG_MASK_INF:
10187
10602
  case GGML_OP_SOFT_MAX:
10188
10603
  case GGML_OP_SOFT_MAX_BACK:
@@ -10345,6 +10760,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
10345
10760
  UNUSED(instance_extensions);
10346
10761
  }
10347
10762
 
10763
+ // Extension availability
10764
+ static bool ggml_vk_instance_debug_utils_ext_available(
10765
+ const std::vector<vk::ExtensionProperties> & instance_extensions) {
10766
+ // Check for portability enumeration extension for MoltenVK support
10767
+ for (const auto & properties : instance_extensions) {
10768
+ if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
10769
+ return true;
10770
+ }
10771
+ }
10772
+
10773
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
10774
+ return false;
10775
+
10776
+ UNUSED(instance_extensions);
10777
+ }
10778
+
10348
10779
  static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
10349
10780
  switch (props.vendorID) {
10350
10781
  case VK_VENDOR_ID_INTEL:
@@ -10457,11 +10888,21 @@ void * comp_result;
10457
10888
  size_t comp_size;
10458
10889
  size_t comp_nb[GGML_MAX_DIMS];
10459
10890
  size_t check_counter = 0;
10460
- static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10891
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
10892
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10461
10893
  if (tensor->op == GGML_OP_TRANSPOSE) {
10462
10894
  return;
10463
10895
  }
10464
10896
 
10897
+ bool fused_rms_norm_mul = false;
10898
+ int rms_norm_idx = -1;
10899
+ if (ctx->num_additional_fused_ops == 1 &&
10900
+ tensor->op == GGML_OP_RMS_NORM &&
10901
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
10902
+ fused_rms_norm_mul = true;
10903
+ tensor = cgraph->nodes[tensor_idx + 1];
10904
+ }
10905
+
10465
10906
  check_counter++;
10466
10907
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10467
10908
  return;
@@ -10489,6 +10930,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10489
10930
 
10490
10931
  for (int i = 0; i < 6; i++) {
10491
10932
  ggml_tensor * srci = tensor->src[i];
10933
+ if (fused_rms_norm_mul) {
10934
+ rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
10935
+ ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
10936
+ switch (i) {
10937
+ case 0: srci = rms_norm->src[0]; break;
10938
+ case 1: srci = tensor->src[1 - rms_norm_idx]; break;
10939
+ default: continue;
10940
+ }
10941
+ }
10492
10942
  if (srci == nullptr) {
10493
10943
  continue;
10494
10944
  }
@@ -10546,7 +10996,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10546
10996
  } else if (tensor->op == GGML_OP_SUB) {
10547
10997
  tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
10548
10998
  } else if (tensor->op == GGML_OP_MUL) {
10549
- tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
10999
+ if (fused_rms_norm_mul) {
11000
+ tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
11001
+ tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
11002
+ } else {
11003
+ tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
11004
+ }
10550
11005
  } else if (tensor->op == GGML_OP_DIV) {
10551
11006
  tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
10552
11007
  } else if (tensor->op == GGML_OP_CONCAT) {
@@ -10634,6 +11089,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10634
11089
  case GGML_UNARY_OP_GELU:
10635
11090
  tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
10636
11091
  break;
11092
+ case GGML_UNARY_OP_GELU_ERF:
11093
+ tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
11094
+ break;
10637
11095
  case GGML_UNARY_OP_GELU_QUICK:
10638
11096
  tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
10639
11097
  break;
@@ -10650,6 +11108,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10650
11108
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
10651
11109
  GGML_ABORT("fatal error");
10652
11110
  }
11111
+ } else if (tensor->op == GGML_OP_GLU) {
11112
+ if (src_clone[1] == nullptr) {
11113
+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
11114
+ } else {
11115
+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
11116
+ }
10653
11117
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
10654
11118
  if (src1 == nullptr) {
10655
11119
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -10657,6 +11121,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10657
11121
  } else {
10658
11122
  tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
10659
11123
  }
11124
+ } else if (tensor->op == GGML_OP_SET_ROWS) {
11125
+ tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
10660
11126
  } else if (tensor->op == GGML_OP_CONT) {
10661
11127
  tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
10662
11128
  } else if (tensor->op == GGML_OP_RESHAPE) {
@@ -10728,10 +11194,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10728
11194
  GGML_ABORT("fatal error");
10729
11195
  }
10730
11196
 
10731
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
10732
- ggml_build_forward_expand(cgraph, tensor_clone);
11197
+ ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
11198
+ ggml_build_forward_expand(cgraph_cpu, tensor_clone);
10733
11199
 
10734
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
11200
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
10735
11201
 
10736
11202
  if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
10737
11203
  ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -10754,10 +11220,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10754
11220
  VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
10755
11221
  }
10756
11222
 
10757
- static void ggml_vk_check_results_1(ggml_tensor * tensor) {
11223
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
11224
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10758
11225
  if (tensor->op == GGML_OP_TRANSPOSE) {
10759
11226
  return;
10760
11227
  }
11228
+ bool fused_rms_norm_mul = false;
11229
+ if (ctx->num_additional_fused_ops == 1 &&
11230
+ tensor->op == GGML_OP_RMS_NORM &&
11231
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
11232
+ fused_rms_norm_mul = true;
11233
+ tensor = cgraph->nodes[tensor_idx + 1];
11234
+ }
11235
+
10761
11236
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10762
11237
  return;
10763
11238
  }