@novastera-oss/llamarn 0.2.9 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  5. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  17. package/cpp/build-info.cpp +2 -2
  18. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  19. package/cpp/llama.cpp/README.md +4 -5
  20. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  21. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  22. package/cpp/llama.cpp/common/arg.cpp +17 -0
  23. package/cpp/llama.cpp/common/chat.cpp +37 -20
  24. package/cpp/llama.cpp/common/chat.h +2 -0
  25. package/cpp/llama.cpp/common/common.h +4 -0
  26. package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
  27. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  28. package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
  29. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  30. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  33. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  35. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  43. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  47. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  66. package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  68. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
  69. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
  70. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
  71. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
  73. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
  92. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  93. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  94. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  95. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  105. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  106. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  117. package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
  118. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  119. package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
  120. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  121. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
  122. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  123. package/cpp/llama.cpp/include/llama.h +0 -40
  124. package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
  125. package/cpp/llama.cpp/src/llama-arch.h +18 -1
  126. package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
  127. package/cpp/llama.cpp/src/llama-batch.h +8 -1
  128. package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
  129. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  130. package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
  131. package/cpp/llama.cpp/src/llama-graph.h +47 -60
  132. package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
  133. package/cpp/llama.cpp/src/llama-hparams.h +3 -0
  134. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
  138. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  139. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  141. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
  142. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  143. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  144. package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
  145. package/cpp/llama.cpp/src/llama-model.h +18 -0
  146. package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
  147. package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
  148. package/cpp/llama.cpp/src/llama-vocab.h +41 -0
  149. package/ios/include/chat.h +2 -0
  150. package/ios/include/common.h +4 -0
  151. package/ios/include/llama.h +0 -40
  152. package/ios/libs/llama.xcframework/Info.plist +19 -19
  153. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  154. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
  155. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  156. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  157. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
  158. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  159. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  160. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  161. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  162. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  163. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  164. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  165. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  168. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
  169. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  170. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  171. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
  172. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  173. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  174. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
  175. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  176. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  177. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  178. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
  179. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  180. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  181. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  182. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
  183. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  184. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  185. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
  186. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  187. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  188. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
  189. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
  190. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  191. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  192. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  193. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  194. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  195. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
  196. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  197. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  198. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
  199. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
  202. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
  203. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  204. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  205. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
  206. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  207. package/package.json +1 -1
  208. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  209. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  210. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  211. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  212. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  213. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  214. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  215. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  216. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  217. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  218. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  219. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  220. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  221. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  222. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  223. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  224. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  225. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  226. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  227. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  228. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  229. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  230. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  231. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  232. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  233. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  234. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  235. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  236. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  237. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  238. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  239. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  240. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  241. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  242. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  243. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  244. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  245. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  246. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  247. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -224,6 +224,21 @@ enum vk_device_architecture {
224
224
  INTEL_XE2,
225
225
  };
226
226
 
227
+ // HSK x HSV
228
+ enum FaHeadSizes {
229
+ FA_HEAD_SIZE_64,
230
+ FA_HEAD_SIZE_80,
231
+ FA_HEAD_SIZE_96,
232
+ FA_HEAD_SIZE_112,
233
+ FA_HEAD_SIZE_128,
234
+ FA_HEAD_SIZE_192,
235
+ FA_HEAD_SIZE_192_128,
236
+ FA_HEAD_SIZE_256,
237
+ FA_HEAD_SIZE_576_512,
238
+ FA_HEAD_SIZE_UNSUPPORTED,
239
+ FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
240
+ };
241
+
227
242
  static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
228
243
  vk::PhysicalDeviceProperties props = device.getProperties();
229
244
 
@@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
305
320
  }
306
321
 
307
322
  struct vk_device_struct {
308
- std::mutex mutex;
323
+ std::recursive_mutex mutex;
309
324
 
310
325
  vk::PhysicalDevice physical_device;
311
326
  vk::PhysicalDeviceProperties properties;
@@ -410,32 +425,42 @@ struct vk_device_struct {
410
425
  vk_pipeline pipeline_div_norepeat[2][2][2];
411
426
 
412
427
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
- vk_pipeline pipeline_upscale_f32;
428
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
414
429
  vk_pipeline pipeline_scale_f32;
415
430
  vk_pipeline pipeline_sqr_f32;
416
431
  vk_pipeline pipeline_sin_f32;
417
432
  vk_pipeline pipeline_cos_f32;
418
433
  vk_pipeline pipeline_clamp_f32;
419
434
  vk_pipeline pipeline_pad_f32;
435
+ vk_pipeline pipeline_roll_f32;
420
436
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
421
437
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
422
438
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
423
439
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
424
440
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
441
+ vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
425
442
  vk_pipeline pipeline_norm_f32;
426
443
  vk_pipeline pipeline_group_norm_f32;
427
444
  vk_pipeline pipeline_rms_norm_f32;
445
+ vk_pipeline pipeline_rms_norm_mul_f32;
428
446
  vk_pipeline pipeline_rms_norm_back_f32;
429
447
  vk_pipeline pipeline_l2_norm_f32;
430
448
 
431
449
  // [src/dst 0=fp32,1=fp16]
432
450
  vk_pipeline pipeline_gelu[2];
451
+ vk_pipeline pipeline_gelu_erf[2];
433
452
  vk_pipeline pipeline_gelu_quick[2];
434
453
  vk_pipeline pipeline_silu[2];
435
454
  vk_pipeline pipeline_relu[2];
436
455
  vk_pipeline pipeline_tanh[2];
437
456
  vk_pipeline pipeline_sigmoid[2];
438
457
 
458
+ vk_pipeline pipeline_geglu[2];
459
+ vk_pipeline pipeline_reglu[2];
460
+ vk_pipeline pipeline_swiglu[2];
461
+ vk_pipeline pipeline_geglu_erf[2];
462
+ vk_pipeline pipeline_geglu_quick[2];
463
+
439
464
  vk_pipeline pipeline_leaky_relu_f32;
440
465
  vk_pipeline pipeline_silu_back_f32;
441
466
  vk_pipeline pipeline_diag_mask_inf_f32;
@@ -461,26 +486,11 @@ struct vk_device_struct {
461
486
  vk_pipeline pipeline_conv2d_dw_cwhn_f32;
462
487
 
463
488
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
464
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
465
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
466
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
467
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
468
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
469
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
470
-
471
- vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
472
- vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
473
- vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
474
- vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
475
- vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
476
- vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
477
-
478
- vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
479
- vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
480
- vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
481
- vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
482
- vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
483
- vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
489
+ vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
490
+
491
+ vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
492
+
493
+ vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
484
494
 
485
495
  vk_pipeline pipeline_flash_attn_split_k_reduce;
486
496
 
@@ -493,6 +503,8 @@ struct vk_device_struct {
493
503
 
494
504
  ggml_backend_buffer_type buffer_type;
495
505
 
506
+ bool disable_fusion;
507
+
496
508
  #ifdef GGML_VULKAN_MEMORY_DEBUG
497
509
  std::unique_ptr<vk_memory_logger> memory_logger;
498
510
  #endif
@@ -627,6 +639,8 @@ struct vk_flash_attn_push_constants {
627
639
  uint32_t nev2;
628
640
  uint32_t nev3;
629
641
  uint32_t nem1;
642
+ uint32_t nem2;
643
+ uint32_t nem3;
630
644
 
631
645
  uint32_t nb01;
632
646
  uint32_t nb02;
@@ -637,14 +651,12 @@ struct vk_flash_attn_push_constants {
637
651
  uint32_t nb21;
638
652
  uint32_t nb22;
639
653
  uint32_t nb23;
640
- uint32_t nb31;
641
654
 
642
655
  float scale;
643
656
  float max_bias;
644
657
  float logit_softcap;
645
658
 
646
- uint32_t mask;
647
- uint32_t n_head_log2;
659
+ uint32_t mask_n_head_log2;
648
660
  float m0;
649
661
  float m1;
650
662
 
@@ -652,6 +664,7 @@ struct vk_flash_attn_push_constants {
652
664
  uint32_t split_kv;
653
665
  uint32_t k_num;
654
666
  };
667
+ static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
655
668
 
656
669
  struct vk_op_push_constants {
657
670
  uint32_t KX;
@@ -660,6 +673,13 @@ struct vk_op_push_constants {
660
673
  float param2;
661
674
  };
662
675
 
676
+ struct vk_op_glu_push_constants {
677
+ uint32_t N;
678
+ uint32_t ne00;
679
+ uint32_t ne20;
680
+ uint32_t mode; // 0: default, 1: swapped, 2: split
681
+ };
682
+
663
683
  struct vk_op_unary_push_constants {
664
684
  uint32_t ne;
665
685
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -675,6 +695,37 @@ struct vk_op_unary_push_constants {
675
695
  };
676
696
  static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
677
697
 
698
+ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
699
+ GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
700
+ ne = ne != 0 ? ne : ggml_nelements(dst);
701
+ GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
702
+
703
+ vk_op_unary_push_constants p{};
704
+ p.ne = (uint32_t)ne;
705
+
706
+ size_t src0_tsize = ggml_type_size(src0->type);
707
+ p.ne00 = (uint32_t)src0->ne[0];
708
+ p.ne01 = (uint32_t)src0->ne[1];
709
+ p.ne02 = (uint32_t)src0->ne[2];
710
+ p.ne03 = (uint32_t)src0->ne[3];
711
+ p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
712
+ p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
713
+ p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
714
+ p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
715
+
716
+ size_t dst_tsize = ggml_type_size(dst->type);
717
+ p.ne10 = (uint32_t)dst->ne[0];
718
+ p.ne11 = (uint32_t)dst->ne[1];
719
+ p.ne12 = (uint32_t)dst->ne[2];
720
+ p.ne13 = (uint32_t)dst->ne[3];
721
+ p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
722
+ p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
723
+ p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
724
+ p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
725
+
726
+ return p; // fastdiv values and offsets are initialized later in ggml_vk_op
727
+ }
728
+
678
729
  // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
679
730
  // Precompute mp (m' in the paper) and L such that division
680
731
  // can be computed using a multiply (high 32b of 64b result)
@@ -743,6 +794,14 @@ struct vk_op_rope_push_constants {
743
794
  struct vk_op_soft_max_push_constants {
744
795
  uint32_t KX;
745
796
  uint32_t KY;
797
+ uint32_t ne00;
798
+ uint32_t ne01;
799
+ uint32_t ne02;
800
+ uint32_t ne12;
801
+ uint32_t ne13;
802
+ uint32_t nb11;
803
+ uint32_t nb12;
804
+ uint32_t nb13;
746
805
  float scale;
747
806
  float max_bias;
748
807
  float m0;
@@ -836,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
836
895
 
837
896
  struct vk_op_upscale_push_constants {
838
897
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
898
+ uint32_t ne00; uint32_t ne01;
839
899
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
840
900
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
841
901
  float sf0; float sf1; float sf2; float sf3;
@@ -978,6 +1038,10 @@ struct ggml_backend_vk_context {
978
1038
 
979
1039
  vk_command_pool compute_cmd_pool;
980
1040
  vk_command_pool transfer_cmd_pool;
1041
+
1042
+ // number of additional consecutive nodes that are being fused with the
1043
+ // node currently being processed
1044
+ int num_additional_fused_ops {};
981
1045
  };
982
1046
 
983
1047
  static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1063,8 +1127,8 @@ static size_t vk_skip_checks;
1063
1127
  static size_t vk_output_tensor;
1064
1128
 
1065
1129
  static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
1066
- static void ggml_vk_check_results_0(ggml_tensor * tensor);
1067
- static void ggml_vk_check_results_1(ggml_tensor * tensor);
1130
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1131
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
1068
1132
  #endif
1069
1133
 
1070
1134
  typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -1197,7 +1261,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
1197
1261
  }
1198
1262
 
1199
1263
  {
1200
- std::lock_guard<std::mutex> guard(device->mutex);
1264
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1201
1265
  device->pipelines.insert({ pipeline->name, pipeline });
1202
1266
  }
1203
1267
 
@@ -1411,7 +1475,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyPrope
1411
1475
 
1412
1476
  static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
1413
1477
  VK_LOG_DEBUG("ggml_vk_create_queue()");
1414
- std::lock_guard<std::mutex> guard(device->mutex);
1478
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
1415
1479
 
1416
1480
  q.queue_family_index = queue_family_index;
1417
1481
  q.transfer_only = transfer_only;
@@ -1673,10 +1737,46 @@ enum FaCodePath {
1673
1737
  FA_COOPMAT2,
1674
1738
  };
1675
1739
 
1740
+ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { // maps an (hsk, hsv) pair to a FaHeadSizes enumerator; presumably hsk/hsv are the K and V head sizes -- confirm with callers
1741
+ if (hsk != 192 && hsk != 576 && hsk != hsv) { // only hsk 192 and 576 may be paired with a different hsv
1742
+ return FA_HEAD_SIZE_UNSUPPORTED;
1743
+ }
1744
+ switch (hsk) { // past the guard above, hsk == hsv for every case except 192 and 576
1745
+ case 64: return FA_HEAD_SIZE_64;
1746
+ case 80: return FA_HEAD_SIZE_80;
1747
+ case 96: return FA_HEAD_SIZE_96;
1748
+ case 112: return FA_HEAD_SIZE_112;
1749
+ case 128: return FA_HEAD_SIZE_128;
1750
+ case 192: // hsk 192 is accepted with hsv 192 or hsv 128 only
1751
+ if (hsv == 192) {
1752
+ return FA_HEAD_SIZE_192;
1753
+ } else if (hsv == 128) {
1754
+ return FA_HEAD_SIZE_192_128;
1755
+ } else {
1756
+ return FA_HEAD_SIZE_UNSUPPORTED;
1757
+ }
1758
+ case 256: return FA_HEAD_SIZE_256;
1759
+ case 576: // hsk 576 is accepted with hsv 512 only; there is no 576/576 enumerator returned here
1760
+ if (hsv == 512) {
1761
+ return FA_HEAD_SIZE_576_512;
1762
+ } else {
1763
+ return FA_HEAD_SIZE_UNSUPPORTED;
1764
+ }
1765
+ default: return FA_HEAD_SIZE_UNSUPPORTED; // any other hsk value has no matching enumerator
1766
+ }
1767
+ }
1768
+
1676
1769
  // number of rows/cols for flash attention shader
1677
1770
  static constexpr uint32_t flash_attention_num_small_rows = 32;
1678
1771
  static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1679
- static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1772
+
1773
+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { // row count used by fa_rows_cols for the FA_SCALAR path when small_rows is false
1774
+ if (hsv >= 512) {
1775
+ return 2; // fewer rows for hsv >= 512; NOTE(review): presumably to limit register/shared-memory use -- confirm
1776
+ } else {
1777
+ return 8; // same value as the scalar_flash_attention_num_large_rows constant this helper replaces
1778
+ }
1779
+ }
1680
1780
 
1681
1781
  // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
1682
1782
  // 128 threads split into four subgroups, each subgroup does 1/4
@@ -1693,14 +1793,15 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
1693
1793
  }
1694
1794
  }
1695
1795
 
1696
- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1796
+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
1697
1797
  GGML_UNUSED(clamp);
1798
+ GGML_UNUSED(hsv);
1698
1799
 
1699
1800
  if (path == FA_SCALAR) {
1700
1801
  if (small_rows) {
1701
1802
  return {scalar_flash_attention_num_small_rows, 64};
1702
1803
  } else {
1703
- return {scalar_flash_attention_num_large_rows, 32};
1804
+ return {get_fa_scalar_num_large_rows(hsv), 32};
1704
1805
  }
1705
1806
  }
1706
1807
 
@@ -1718,8 +1819,12 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
1718
1819
  }
1719
1820
 
1720
1821
  // small cols to reduce register count
1721
- if (ggml_is_quantized(type) || D == 256) {
1722
- return {64, 32};
1822
+ if (ggml_is_quantized(type) || hsk >= 256) {
1823
+ if (hsk >= 512) {
1824
+ return {32, 32};
1825
+ } else {
1826
+ return {64, 32};
1827
+ }
1723
1828
  }
1724
1829
  return {64, 64};
1725
1830
  }
@@ -1761,7 +1866,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1761
1866
  const uint32_t warps = warptile[0] / warptile[10];
1762
1867
 
1763
1868
  const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1764
- const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1869
+ const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
1765
1870
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1766
1871
 
1767
1872
  const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1886,10 +1991,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1886
1991
  s_mmq_wg_denoms_k = { 32, 32, 1 };
1887
1992
 
1888
1993
  // spec constants and tile sizes for quant matmul_id
1889
- l_warptile_mmqid = { 256, 128, 64, 16, 0 };
1994
+ l_warptile_mmqid = { 256, 128, 128, 16, 0 };
1890
1995
  m_warptile_mmqid = { 256, 128, 64, 16, 0 };
1891
1996
  s_warptile_mmqid = { 256, 128, 64, 16, 0 };
1892
- l_mmqid_wg_denoms = { 128, 64, 1 };
1997
+ l_mmqid_wg_denoms = { 128, 128, 1 };
1893
1998
  m_mmqid_wg_denoms = { 128, 64, 1 };
1894
1999
  s_mmqid_wg_denoms = { 128, 64, 1 };
1895
2000
 
@@ -2011,19 +2116,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
2011
2116
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
2012
2117
  };
2013
2118
 
2014
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2015
- return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
2119
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
2120
+ return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
2016
2121
  };
2017
2122
 
2018
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2123
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
2019
2124
  // For large number of rows, 128 invocations seems to work best.
2020
2125
  // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
2021
2126
  // can't use 256 for D==80.
2022
2127
  // For scalar, use 128 (arbitrary)
2128
+ // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
2129
+ const uint32_t D = (hsk|hsv);
2023
2130
  uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
2024
2131
  ? scalar_flash_attention_workgroup_size
2025
2132
  : ((small_rows && (D % 32) == 0) ? 256 : 128);
2026
- auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
2133
+ auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
2027
2134
 
2028
2135
  // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
2029
2136
  // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2032,26 +2139,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
2032
2139
 
2033
2140
  // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
2034
2141
  GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
2035
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
2142
+ return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
2036
2143
  };
2037
2144
 
2038
- #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
2039
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2040
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2041
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2042
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2043
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2044
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2045
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2046
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2145
+ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
2146
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2147
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2148
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2149
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2150
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2151
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2152
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2153
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
2047
2154
 
2048
2155
  #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
2049
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
2050
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
2051
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
2052
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
2053
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
2054
- CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
2156
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
2157
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
2158
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
2159
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
2160
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
2161
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
2162
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
2163
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
2164
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
2055
2165
 
2056
2166
  CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
2057
2167
  CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2641,7 +2751,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2641
2751
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2642
2752
 
2643
2753
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2644
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
2754
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
2645
2755
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
2646
2756
 
2647
2757
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -2655,7 +2765,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2655
2765
 
2656
2766
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2657
2767
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2768
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2769
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
2659
2770
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2660
2771
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2661
2772
 
@@ -2672,19 +2783,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
2672
2783
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2673
2784
 
2674
2785
  if (device->float_controls_rte_fp16) {
2675
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2676
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2677
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2678
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2679
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2680
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2786
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2787
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2788
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2789
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2790
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2791
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2681
2792
  } else {
2682
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2683
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2684
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2685
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2686
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2687
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2793
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2794
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2795
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2796
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2797
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2798
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2799
+ }
2800
+
2801
+ if (device->float_controls_rte_fp16) {
2802
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2803
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2804
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2805
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2806
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2807
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2808
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2809
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2810
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2811
+ } else {
2812
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2813
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2814
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2815
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2816
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2817
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2818
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2819
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2820
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2688
2821
  }
2689
2822
 
2690
2823
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -2724,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2724
2857
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2725
2858
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2726
2859
 
2727
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2860
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2861
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2862
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2728
2863
 
2729
2864
  ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2730
2865
 
@@ -2736,6 +2871,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2736
2871
 
2737
2872
  ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2738
2873
 
2874
+ ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2875
+
2739
2876
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2740
2877
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2741
2878
 
@@ -2744,6 +2881,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2744
2881
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2745
2882
 
2746
2883
  CREATE_UNARY(gelu)
2884
+ CREATE_UNARY(gelu_erf)
2747
2885
  CREATE_UNARY(gelu_quick)
2748
2886
  CREATE_UNARY(silu)
2749
2887
  CREATE_UNARY(relu)
@@ -2751,6 +2889,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
2751
2889
  CREATE_UNARY(sigmoid)
2752
2890
  #undef CREATE_UNARY
2753
2891
 
2892
+ #define CREATE_GLU(name) \
2893
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2894
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
2895
+
2896
+ CREATE_GLU(geglu)
2897
+ CREATE_GLU(reglu)
2898
+ CREATE_GLU(swiglu)
2899
+ CREATE_GLU(geglu_erf)
2900
+ CREATE_GLU(geglu_quick)
2901
+ #undef CREATE_GLU
2902
+
2754
2903
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2755
2904
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2756
2905
 
@@ -3431,6 +3580,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
3431
3580
 
3432
3581
  device->idx = idx;
3433
3582
 
3583
+ device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
3584
+
3434
3585
  return device;
3435
3586
  }
3436
3587
 
@@ -3651,7 +3802,6 @@ static void ggml_vk_instance_init() {
3651
3802
 
3652
3803
  }
3653
3804
 
3654
- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
3655
3805
  vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
3656
3806
 
3657
3807
  // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -4124,6 +4274,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
4124
4274
  return nullptr;
4125
4275
  }
4126
4276
 
4277
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4127
4278
  device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
4128
4279
 
4129
4280
  return buf->ptr;
@@ -4134,6 +4285,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4134
4285
  return;
4135
4286
  }
4136
4287
  VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
4288
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4289
+
4137
4290
  vk_buffer buf;
4138
4291
  size_t index;
4139
4292
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4156,6 +4309,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
4156
4309
  }
4157
4310
 
4158
4311
  static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
4312
+ std::lock_guard<std::recursive_mutex> guard(device->mutex);
4159
4313
  buf = nullptr;
4160
4314
  buf_offset = 0;
4161
4315
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -4457,7 +4611,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
4457
4611
  memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
4458
4612
  }
4459
4613
  } else {
4460
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4614
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4461
4615
 
4462
4616
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4463
4617
  ggml_vk_ctx_begin(dst->device, subctx);
@@ -4548,7 +4702,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
4548
4702
 
4549
4703
  memcpy(dst, (uint8_t *) src->ptr + offset, size);
4550
4704
  } else {
4551
- std::lock_guard<std::mutex> guard(src->device->mutex);
4705
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4552
4706
 
4553
4707
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
4554
4708
  ggml_vk_ctx_begin(src->device, subctx);
@@ -4578,7 +4732,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
4578
4732
 
4579
4733
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
4580
4734
  if (src->device == dst->device) {
4581
- std::lock_guard<std::mutex> guard(src->device->mutex);
4735
+ std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
4582
4736
  VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
4583
4737
  // Copy within the device
4584
4738
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
@@ -4613,7 +4767,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
4613
4767
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
4614
4768
  VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
4615
4769
 
4616
- std::lock_guard<std::mutex> guard(dst->device->mutex);
4770
+ std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
4617
4771
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
4618
4772
  ggml_vk_ctx_begin(dst->device, subctx);
4619
4773
  subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
@@ -4840,9 +4994,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4840
4994
  // type size must be exactly 2 or 4.
4841
4995
  GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
4842
4996
  if ((ggml_type_size(src->type) % 4) == 0) {
4843
- return ctx->device->pipeline_contig_cpy_f32_f32;
4997
+ if (contig) {
4998
+ return ctx->device->pipeline_contig_cpy_f32_f32;
4999
+ } else {
5000
+ return ctx->device->pipeline_cpy_f32_f32;
5001
+ }
4844
5002
  } else {
4845
- return ctx->device->pipeline_contig_cpy_f16_f16;
5003
+ if (contig) {
5004
+ return ctx->device->pipeline_contig_cpy_f16_f16;
5005
+ } else {
5006
+ return ctx->device->pipeline_cpy_f16_f16;
5007
+ }
4846
5008
  }
4847
5009
  }
4848
5010
 
@@ -4903,7 +5065,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4903
5065
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
4904
5066
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4905
5067
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4906
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5068
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
4907
5069
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
4908
5070
 
4909
5071
  const uint64_t ne00 = src0->ne[0];
@@ -5131,7 +5293,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
5131
5293
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
5132
5294
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5133
5295
  std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)");
5134
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5296
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5135
5297
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5136
5298
 
5137
5299
  const uint64_t ne00 = src0->ne[0];
@@ -5732,7 +5894,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
5732
5894
  std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
5733
5895
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5734
5896
  std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
5735
- GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
5897
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
5736
5898
  GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
5737
5899
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
5738
5900
 
@@ -5926,14 +6088,60 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
5926
6088
  if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5927
6089
  ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
5928
6090
  } else {
5929
- ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
6091
+ // Split based on number of ids, to fit in shared memory
6092
+ const uint32_t nei0 = (uint32_t)src2->ne[0];
6093
+ const uint32_t nei1 = (uint32_t)src2->ne[1];
6094
+
6095
+ GGML_ASSERT(nei0 <= 4096);
6096
+ const uint32_t split_size = std::min(nei1, 4096u / nei0);
6097
+
6098
+ ggml_tensor src1_copy = *src1;
6099
+ ggml_tensor src2_copy = *src2;
6100
+ ggml_tensor dst_copy = *dst;
6101
+
6102
+ for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
6103
+ const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
6104
+
6105
+ src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
6106
+ src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
6107
+ dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
6108
+
6109
+ src1_copy.ne[2] = n_tokens;
6110
+ src2_copy.ne[1] = n_tokens;
6111
+ dst_copy.ne[2] = n_tokens;
6112
+
6113
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
6114
+ }
5930
6115
  }
5931
6116
  }
5932
6117
 
5933
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
6118
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
6119
+ // Host-side estimate of the shared memory needed for head sizes hsk/hsv; returns true if it fits within the device's maxComputeSharedMemorySize. Needs to be kept up to date on shader changes
6120
+ GGML_UNUSED(hsv); // NOTE(review): hsv is read two lines below, so this looks like a redundant leftover - confirm before removing
6121
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6122
+ const uint32_t Br = get_fa_scalar_num_large_rows(hsv); // row count is taken from the "large rows" configuration for this hsv
6123
+ const uint32_t Bc = scalar_flash_attention_Bc;
6124
+
6125
+ const uint32_t tmpsh = wg_size * sizeof(float); // wg_size floats
6126
+ const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); // wg_size groups of 4 floats
6127
+
6128
+ const uint32_t masksh = Bc * Br * sizeof(float); // Bc * Br floats
6129
+
6130
+ const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); // Br rows of (hsk / 4 + 2) 4-float groups
6131
+
6132
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
6133
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
6134
+
6135
+ VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); // log under this function's own name, not the coopmat variant's
6136
+
6137
+ return supported;
6138
+ }
6139
+
6140
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
5934
6141
  // Needs to be kept up to date on shader changes
6142
+ GGML_UNUSED(hsv);
5935
6143
  const uint32_t wg_size = scalar_flash_attention_workgroup_size;
5936
- const uint32_t Br = scalar_flash_attention_num_large_rows;
6144
+ const uint32_t Br = coopmat1_flash_attention_num_large_rows;
5937
6145
  const uint32_t Bc = scalar_flash_attention_Bc;
5938
6146
 
5939
6147
  const uint32_t acctype = f32acc ? 4 : 2;
@@ -5942,12 +6150,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5942
6150
  const uint32_t tmpsh = wg_size * sizeof(float);
5943
6151
  const uint32_t tmpshv4 = wg_size * 4 * acctype;
5944
6152
 
5945
- const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
6153
+ const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
5946
6154
 
5947
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
6155
+ const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
5948
6156
  const uint32_t sfsh = Bc * sfshstride * acctype;
5949
6157
 
5950
- const uint32_t kshstride = D / 4 + 2;
6158
+ const uint32_t kshstride = hsk / 4 + 2;
5951
6159
  const uint32_t ksh = Bc * kshstride * f16vec4;
5952
6160
 
5953
6161
  const uint32_t slope = Br * sizeof(float);
@@ -5955,7 +6163,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
5955
6163
  const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
5956
6164
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
5957
6165
 
5958
- VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
6166
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
5959
6167
 
5960
6168
  return supported;
5961
6169
  }
@@ -5977,13 +6185,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5977
6185
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
5978
6186
 
5979
6187
  const uint32_t nem1 = mask ? mask->ne[1] : 0;
5980
- const uint32_t nbm1 = mask ? mask->nb[1] : 0;
6188
+ const uint32_t nem2 = mask ? mask->ne[2] : 0;
6189
+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
5981
6190
 
5982
- const uint32_t D = neq0;
6191
+ const uint32_t HSK = nek0;
6192
+ const uint32_t HSV = nev0;
5983
6193
  uint32_t N = neq1;
5984
6194
  const uint32_t KV = nek1;
5985
6195
 
5986
- GGML_ASSERT(ne0 == D);
6196
+ GGML_ASSERT(ne0 == HSV);
5987
6197
  GGML_ASSERT(ne2 == N);
5988
6198
 
5989
6199
  // input tensor rows must be contiguous
@@ -5991,12 +6201,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5991
6201
  GGML_ASSERT(nbk0 == ggml_type_size(k->type));
5992
6202
  GGML_ASSERT(nbv0 == ggml_type_size(v->type));
5993
6203
 
5994
- GGML_ASSERT(neq0 == D);
5995
- GGML_ASSERT(nek0 == D);
5996
- GGML_ASSERT(nev0 == D);
6204
+ GGML_ASSERT(neq0 == HSK);
5997
6205
 
5998
6206
  GGML_ASSERT(neq1 == N);
5999
- GGML_ASSERT(nev0 == D);
6000
6207
 
6001
6208
  GGML_ASSERT(nev1 == nek1);
6002
6209
 
@@ -6017,7 +6224,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6017
6224
  const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
6018
6225
  (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
6019
6226
 
6020
- const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
6227
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
6021
6228
 
6022
6229
  if (!coopmat_shape_supported || !coopmat_shmem_supported) {
6023
6230
  path = FA_SCALAR;
@@ -6037,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6037
6244
  case FA_SCALAR:
6038
6245
  case FA_COOPMAT1:
6039
6246
  // We may switch from coopmat1 to scalar, so use the scalar limit for both
6040
- max_gqa = scalar_flash_attention_num_large_rows;
6247
+ max_gqa = get_fa_scalar_num_large_rows(HSV);
6041
6248
  break;
6042
6249
  case FA_COOPMAT2:
6043
6250
  max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@@ -6047,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6047
6254
  }
6048
6255
 
6049
6256
  if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6050
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6257
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
6051
6258
  // grouped query attention - make the N dimension equal to gqa_ratio, reduce
6052
6259
  // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
6053
6260
  // and change addressing calculations to index Q's dimension 2.
@@ -6070,47 +6277,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6070
6277
  path = FA_SCALAR;
6071
6278
  }
6072
6279
 
6280
+ // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
6281
+ if (path == FA_SCALAR &&
6282
+ !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
6283
+ small_rows = true;
6284
+ }
6285
+
6073
6286
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
6074
6287
 
6288
+ FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
6289
+
6075
6290
  switch (path) {
6076
6291
  case FA_SCALAR:
6077
- switch (D) {
6078
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
6079
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
6080
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
6081
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
6082
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
6083
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
6084
- default:
6085
- GGML_ASSERT(!"unsupported D value");
6086
- return;
6087
- }
6292
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
6088
6293
  break;
6089
6294
  case FA_COOPMAT1:
6090
- switch (D) {
6091
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
6092
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
6093
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
6094
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
6095
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
6096
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
6097
- default:
6098
- GGML_ASSERT(!"unsupported D value");
6099
- return;
6100
- }
6295
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
6101
6296
  break;
6102
6297
  case FA_COOPMAT2:
6103
- switch (D) {
6104
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
6105
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
6106
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
6107
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
6108
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
6109
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
6110
- default:
6111
- GGML_ASSERT(!"unsupported D value");
6112
- return;
6113
- }
6298
+ pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
6114
6299
  break;
6115
6300
  default:
6116
6301
  GGML_ASSERT(0);
@@ -6138,21 +6323,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6138
6323
  const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
6139
6324
 
6140
6325
  // Try to use split_k when KV is large enough to be worth the overhead
6141
- if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
6326
+ if (workgroups_x == 1 && shader_core_count > 0) {
6142
6327
  // Try to run two workgroups per SM.
6143
- split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6328
+ split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
6144
6329
  if (split_k > 1) {
6145
6330
  // Try to evenly split KV into split_k chunks, but it needs to be a multiple
6146
6331
  // of "align", so recompute split_k based on that.
6147
- split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
6332
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
6148
6333
  split_k = CEIL_DIV(KV, split_kv);
6149
6334
  workgroups_x = split_k;
6150
6335
  }
6151
6336
  }
6152
6337
 
6153
- // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6154
- // and the per-row m and L values (ne1 rows).
6155
- const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6338
+ // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (HSV x ne1)
6339
+ // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6340
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
6156
6341
  if (split_k_size > ctx->device->max_memory_allocation_size) {
6157
6342
  GGML_ABORT("Requested preallocation size is too large");
6158
6343
  }
@@ -6239,18 +6424,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6239
6424
  }
6240
6425
  }
6241
6426
 
6427
+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6428
+
6242
6429
  const vk_flash_attn_push_constants pc = { N, KV,
6243
6430
  (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
6244
6431
  (uint32_t)neq2, (uint32_t)neq3,
6245
6432
  (uint32_t)nek2, (uint32_t)nek3,
6246
6433
  (uint32_t)nev2, (uint32_t)nev3,
6247
- nem1,
6434
+ nem1, nem2, nem3,
6248
6435
  q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
6249
6436
  k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
6250
6437
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6251
- nbm1,
6252
6438
  scale, max_bias, logit_softcap,
6253
- mask != nullptr, n_head_log2, m0, m1,
6439
+ mask_n_head_log2, m0, m1,
6254
6440
  gqa_ratio, split_kv, split_k };
6255
6441
 
6256
6442
  ggml_vk_sync_buffers(subctx);
@@ -6271,13 +6457,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
6271
6457
  pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
6272
6458
 
6273
6459
  ggml_vk_sync_buffers(subctx);
6274
- const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6460
+ const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
6275
6461
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
6276
6462
  {
6277
6463
  vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
6278
6464
  vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
6279
6465
  },
6280
- pc2, { (uint32_t)ne1, 1, 1 });
6466
+ pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
6281
6467
  } else {
6282
6468
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
6283
6469
  {
@@ -6353,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6353
6539
  }
6354
6540
  return nullptr;
6355
6541
  case GGML_OP_UPSCALE:
6356
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6357
- return ctx->device->pipeline_upscale_f32;
6542
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6543
+ int mode = ggml_get_op_params_i32(dst, 0);
6544
+ switch (mode) {
6545
+ case GGML_SCALE_MODE_NEAREST:
6546
+ return ctx->device->pipeline_upscale_nearest_f32;
6547
+ case GGML_SCALE_MODE_BILINEAR:
6548
+ return ctx->device->pipeline_upscale_bilinear_f32;
6549
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6550
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6551
+ }
6358
6552
  }
6359
6553
  return nullptr;
6360
6554
  case GGML_OP_SCALE:
@@ -6387,6 +6581,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6387
6581
  return ctx->device->pipeline_pad_f32;
6388
6582
  }
6389
6583
  return nullptr;
6584
+ case GGML_OP_ROLL:
6585
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6586
+ return ctx->device->pipeline_roll_f32;
6587
+ }
6588
+ return nullptr;
6390
6589
  case GGML_OP_REPEAT:
6391
6590
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
6392
6591
  return ctx->device->pipeline_repeat_f32;
@@ -6401,6 +6600,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6401
6600
  case GGML_OP_CONT:
6402
6601
  case GGML_OP_DUP:
6403
6602
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
6603
+ case GGML_OP_SET_ROWS:
6604
+ return ctx->device->pipeline_set_rows[dst->type];
6404
6605
  case GGML_OP_SILU_BACK:
6405
6606
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6406
6607
  return ctx->device->pipeline_silu_back_f32;
@@ -6418,7 +6619,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6418
6619
  return nullptr;
6419
6620
  case GGML_OP_RMS_NORM:
6420
6621
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6421
- return ctx->device->pipeline_rms_norm_f32;
6622
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
6422
6623
  }
6423
6624
  return nullptr;
6424
6625
  case GGML_OP_RMS_NORM_BACK:
@@ -6443,6 +6644,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6443
6644
  return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
6444
6645
  case GGML_UNARY_OP_GELU:
6445
6646
  return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6647
+ case GGML_UNARY_OP_GELU_ERF:
6648
+ return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
6446
6649
  case GGML_UNARY_OP_GELU_QUICK:
6447
6650
  return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
6448
6651
  case GGML_UNARY_OP_RELU:
@@ -6455,6 +6658,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6455
6658
  break;
6456
6659
  }
6457
6660
  return nullptr;
6661
+ case GGML_OP_GLU:
6662
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6663
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6664
+ (src0->type != dst->type)) {
6665
+ return nullptr;
6666
+ }
6667
+
6668
+ switch (ggml_get_glu_op(dst)) {
6669
+ case GGML_GLU_OP_GEGLU:
6670
+ return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
6671
+ case GGML_GLU_OP_REGLU:
6672
+ return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6673
+ case GGML_GLU_OP_SWIGLU:
6674
+ return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6675
+ case GGML_GLU_OP_GEGLU_ERF:
6676
+ return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
6677
+ case GGML_GLU_OP_GEGLU_QUICK:
6678
+ return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
6679
+ default:
6680
+ break;
6681
+ }
6682
+ return nullptr;
6458
6683
  case GGML_OP_DIAG_MASK_INF:
6459
6684
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6460
6685
  return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6615,6 +6840,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
6615
6840
  case GGML_OP_RMS_NORM:
6616
6841
  case GGML_OP_CONV_2D_DW:
6617
6842
  case GGML_OP_IM2COL:
6843
+ case GGML_OP_SET_ROWS:
6618
6844
  return true;
6619
6845
  default:
6620
6846
  return false;
@@ -6909,12 +7135,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6909
7135
  case GGML_OP_COS:
6910
7136
  case GGML_OP_CLAMP:
6911
7137
  case GGML_OP_PAD:
7138
+ case GGML_OP_ROLL:
6912
7139
  case GGML_OP_REPEAT:
6913
7140
  case GGML_OP_REPEAT_BACK:
6914
7141
  case GGML_OP_CPY:
6915
7142
  case GGML_OP_CONCAT:
6916
7143
  case GGML_OP_UPSCALE:
6917
7144
  case GGML_OP_UNARY:
7145
+ case GGML_OP_GLU:
6918
7146
  case GGML_OP_CONV_2D_DW:
6919
7147
  {
6920
7148
  uint32_t ne = ggml_nelements(dst);
@@ -6927,6 +7155,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6927
7155
  ne *= ggml_type_size(src0->type) / 2;
6928
7156
  }
6929
7157
  }
7158
+ // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
7159
+ // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
7160
+ // So divide by block size here before splitting into 512x512 groups.
7161
+ if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7162
+ ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7163
+ }
6930
7164
  if (ne > 262144) {
6931
7165
  elements = { 512, 512, CEIL_DIV(ne, 262144) };
6932
7166
  } else if (ne > 512) {
@@ -6935,6 +7169,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6935
7169
  elements = { ne, 1, 1 };
6936
7170
  }
6937
7171
  } break;
7172
+ case GGML_OP_SET_ROWS:
7173
+ {
7174
+ uint32_t ne = ggml_nelements(src0);
7175
+ if (ggml_is_quantized(dst->type)) {
7176
+ // quants run 32 threads each doing QUANT_K elements
7177
+ ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
7178
+ } else {
7179
+ // scalar types do one element per thread, running 512 threads
7180
+ ne = CEIL_DIV(ne, 512);
7181
+ }
7182
+ if (ne > 262144) {
7183
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
7184
+ } else if (ne > 512) {
7185
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
7186
+ } else {
7187
+ elements = { ne, 1, 1 };
7188
+ }
7189
+ }
7190
+ break;
6938
7191
  default:
6939
7192
  elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
6940
7193
  break;
@@ -6955,7 +7208,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6955
7208
  }
6956
7209
  }
6957
7210
 
6958
- if (op == GGML_OP_SOFT_MAX) {
7211
+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
6959
7212
  // Empty src1 is possible in soft_max, but the shader needs a buffer
6960
7213
  vk_subbuffer subbuf_y;
6961
7214
  if (use_src1) {
@@ -7344,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7344
7597
 
7345
7598
  static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7346
7599
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7600
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7601
+
7602
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7603
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7604
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7605
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7347
7606
 
7348
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7349
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7350
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7351
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7607
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7608
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7609
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7610
+ }
7352
7611
 
7353
7612
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7354
7613
  (uint32_t)ggml_nelements(dst), 0, 0,
7614
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7355
7615
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7356
7616
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7357
7617
  sf0, sf1, sf2, sf3,
@@ -7359,123 +7619,64 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
7359
7619
  }
7360
7620
 
7361
7621
  static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7362
- float * op_params = (float *)dst->op_params;
7363
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7364
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7622
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7623
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7624
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7365
7625
 
7366
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7367
- (uint32_t)ggml_nelements(src0),
7368
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7369
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7370
- 0,
7371
- op_params[0], 0.0f,
7372
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7373
- }, dryrun);
7626
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
7374
7627
  }
7375
7628
 
7376
7629
  static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7377
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7378
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7379
-
7380
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7381
- (uint32_t)ggml_nelements(src0),
7382
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7383
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7384
- 0,
7385
- 0.0f, 0.0f,
7386
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7387
- }, dryrun);
7630
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
7388
7631
  }
7389
7632
 
7390
7633
  static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7391
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7392
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7393
-
7394
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7395
- (uint32_t)ggml_nelements(src0),
7396
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7397
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7398
- 0,
7399
- 0.0f, 0.0f,
7400
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7401
- }, dryrun);
7634
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
7402
7635
  }
7403
7636
 
7404
7637
  static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7405
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7406
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7407
-
7408
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7409
- (uint32_t)ggml_nelements(src0),
7410
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7411
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7412
- 0,
7413
- 0.0f, 0.0f,
7414
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7415
- }, dryrun);
7638
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
7416
7639
  }
7417
7640
 
7418
7641
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7419
- float * op_params = (float *)dst->op_params;
7420
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7421
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7642
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7643
+ p.param1 = ggml_get_op_params_f32(dst, 0);
7644
+ p.param2 = ggml_get_op_params_f32(dst, 1);
7422
7645
 
7423
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7424
- (uint32_t)ggml_nelements(src0),
7425
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7426
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7427
- 0,
7428
- op_params[0], op_params[1],
7429
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7430
- }, dryrun);
7646
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
7431
7647
  }
7432
7648
 
7433
7649
  static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7434
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7435
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7650
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7651
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7652
+ }
7436
7653
 
7437
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7438
- (uint32_t)ggml_nelements(dst),
7439
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7440
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7441
- 0,
7442
- 0.0f, 0.0f,
7443
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7444
- }, dryrun);
7654
+ static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7655
+ const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7656
+ const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7657
+ const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7658
+ const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7659
+ const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7660
+ const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7661
+
7662
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7663
+ memcpy(&p.param1, &s01_packed, sizeof(float));
7664
+ memcpy(&p.param2, &s23_packed, sizeof(float));
7665
+
7666
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
7445
7667
  }
7446
7668
 
7447
7669
  static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7448
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7449
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7450
-
7451
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7452
- (uint32_t)ggml_nelements(dst),
7453
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7454
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7455
- 0,
7456
- 0.0f, 0.0f,
7457
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7458
- }, dryrun);
7670
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7671
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
7459
7672
  }
7460
7673
 
7461
7674
  static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7462
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7463
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7464
-
7465
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7466
- (uint32_t)ggml_nelements(dst),
7467
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7468
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7469
- 0,
7470
- 0.0f, 0.0f,
7471
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7472
- }, dryrun);
7675
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7676
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
7473
7677
  }
7474
7678
 
7475
7679
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7476
- const uint32_t src0_type_size = ggml_type_size(src0->type);
7477
- const uint32_t dst_type_size = ggml_type_size(dst->type);
7478
-
7479
7680
  uint32_t ne = (uint32_t)ggml_nelements(src0);
7480
7681
  if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7481
7682
  // Convert from number of logical elements to 2- or 4-byte units.
@@ -7487,13 +7688,22 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
7487
7688
  }
7488
7689
  }
7489
7690
 
7490
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7491
- ne,
7492
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7493
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7691
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7692
+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
7693
+ }
7694
+
7695
+ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7696
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
7697
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7698
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
7699
+
7700
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
7701
+ (uint32_t)ggml_nelements(src0),
7702
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7703
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7704
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7494
7705
  0,
7495
- 0.0f, 0.0f,
7496
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7706
+ 0.0f, 0.0f, 0,
7497
7707
  }, dryrun);
7498
7708
  }
7499
7709
 
@@ -7518,18 +7728,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
7518
7728
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
7519
7729
  }
7520
7730
 
7521
- static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7522
- float * op_params = (float *)dst->op_params;
7731
+ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
7523
7732
  const uint32_t src0_type_size = ggml_type_size(src0->type);
7733
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
7524
7734
  const uint32_t dst_type_size = ggml_type_size(dst->type);
7525
7735
 
7526
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7736
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
7527
7737
  (uint32_t)ggml_nelements(src0),
7528
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7529
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7738
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7739
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7740
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7530
7741
  0,
7531
- op_params[0], 0.0f,
7532
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7742
+ op_params[0], 0.0f, 0,
7533
7743
  }, dryrun);
7534
7744
  }
7535
7745
 
@@ -7547,6 +7757,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
7547
7757
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7548
7758
  }
7549
7759
 
7760
+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7761
+ const bool swapped = (bool)dst->op_params[1];
7762
+ const bool split = src1 != nullptr;
7763
+
7764
+ GGML_ASSERT(ggml_is_contiguous(src0));
7765
+
7766
+ if (!split) {
7767
+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7768
+ } else {
7769
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7770
+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7771
+ GGML_ASSERT(src0->type == src1->type);
7772
+ }
7773
+
7774
+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
7775
+
7776
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
7777
+ }
7778
+
7550
7779
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7551
7780
  int32_t * op_params = (int32_t *)dst->op_params;
7552
7781
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -7562,7 +7791,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7562
7791
  const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
7563
7792
  const uint32_t nrows_y = (uint32_t)src0->ne[1];
7564
7793
 
7565
- const uint32_t n_head_kv = nrows_x/nrows_y;
7794
+ const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7795
+ const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7796
+ const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7797
+ const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7798
+ const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7799
+
7800
+ const uint32_t n_head_kv = src0->ne[2];
7566
7801
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
7567
7802
 
7568
7803
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7571,6 +7806,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
7571
7806
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
7572
7807
  ncols,
7573
7808
  src1 != nullptr ? nrows_y : (uint32_t)0,
7809
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7810
+ ne12, ne13,
7811
+ nb11, nb12, nb13,
7574
7812
  scale, max_bias,
7575
7813
  m0, m1,
7576
7814
  n_head_log2,
@@ -8720,11 +8958,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
8720
8958
  }
8721
8959
  }
8722
8960
 
8723
- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
8961
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
8724
8962
 
8725
8963
  // Returns true if node has enqueued work into the queue, false otherwise
8726
8964
  // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
8727
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8965
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8966
+ ggml_tensor * node = cgraph->nodes[node_idx];
8728
8967
  if (ggml_is_empty(node) || !node->buffer) {
8729
8968
  return false;
8730
8969
  }
@@ -8749,6 +8988,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8749
8988
  switch (ggml_get_unary_op(node)) {
8750
8989
  case GGML_UNARY_OP_SILU:
8751
8990
  case GGML_UNARY_OP_GELU:
8991
+ case GGML_UNARY_OP_GELU_ERF:
8752
8992
  case GGML_UNARY_OP_GELU_QUICK:
8753
8993
  case GGML_UNARY_OP_RELU:
8754
8994
  case GGML_UNARY_OP_TANH:
@@ -8758,6 +8998,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8758
8998
  return false;
8759
8999
  }
8760
9000
  break;
9001
+ case GGML_OP_GLU:
9002
+ switch (ggml_get_glu_op(node)) {
9003
+ case GGML_GLU_OP_GEGLU:
9004
+ case GGML_GLU_OP_REGLU:
9005
+ case GGML_GLU_OP_SWIGLU:
9006
+ case GGML_GLU_OP_GEGLU_ERF:
9007
+ case GGML_GLU_OP_GEGLU_QUICK:
9008
+ break;
9009
+ default:
9010
+ return false;
9011
+ }
9012
+ break;
8761
9013
  case GGML_OP_REPEAT:
8762
9014
  case GGML_OP_REPEAT_BACK:
8763
9015
  case GGML_OP_GET_ROWS:
@@ -8774,7 +9026,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8774
9026
  case GGML_OP_COS:
8775
9027
  case GGML_OP_CLAMP:
8776
9028
  case GGML_OP_PAD:
9029
+ case GGML_OP_ROLL:
8777
9030
  case GGML_OP_CPY:
9031
+ case GGML_OP_SET_ROWS:
8778
9032
  case GGML_OP_CONT:
8779
9033
  case GGML_OP_DUP:
8780
9034
  case GGML_OP_SILU_BACK:
@@ -8841,6 +9095,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8841
9095
  case GGML_OP_CLAMP:
8842
9096
  case GGML_OP_PAD:
8843
9097
  case GGML_OP_CPY:
9098
+ case GGML_OP_SET_ROWS:
8844
9099
  case GGML_OP_CONT:
8845
9100
  case GGML_OP_DUP:
8846
9101
  case GGML_OP_SILU_BACK:
@@ -8850,6 +9105,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8850
9105
  case GGML_OP_RMS_NORM_BACK:
8851
9106
  case GGML_OP_L2_NORM:
8852
9107
  case GGML_OP_UNARY:
9108
+ case GGML_OP_GLU:
8853
9109
  case GGML_OP_DIAG_MASK_INF:
8854
9110
  case GGML_OP_SOFT_MAX:
8855
9111
  case GGML_OP_SOFT_MAX_BACK:
@@ -8942,12 +9198,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8942
9198
  case GGML_OP_PAD:
8943
9199
  ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
8944
9200
 
9201
+ break;
9202
+ case GGML_OP_ROLL:
9203
+ ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9204
+
8945
9205
  break;
8946
9206
  case GGML_OP_CPY:
8947
9207
  case GGML_OP_CONT:
8948
9208
  case GGML_OP_DUP:
8949
9209
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
8950
9210
 
9211
+ break;
9212
+ case GGML_OP_SET_ROWS:
9213
+ ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
9214
+
8951
9215
  break;
8952
9216
  case GGML_OP_SILU_BACK:
8953
9217
  ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8962,8 +9226,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8962
9226
 
8963
9227
  break;
8964
9228
  case GGML_OP_RMS_NORM:
8965
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8966
-
9229
+ if (ctx->num_additional_fused_ops > 0) {
9230
+ // fused rms_norm + mul
9231
+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9232
+ ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
9233
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
9234
+ } else {
9235
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
9236
+ }
8967
9237
  break;
8968
9238
  case GGML_OP_RMS_NORM_BACK:
8969
9239
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8977,6 +9247,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8977
9247
  switch (ggml_get_unary_op(node)) {
8978
9248
  case GGML_UNARY_OP_SILU:
8979
9249
  case GGML_UNARY_OP_GELU:
9250
+ case GGML_UNARY_OP_GELU_ERF:
8980
9251
  case GGML_UNARY_OP_GELU_QUICK:
8981
9252
  case GGML_UNARY_OP_RELU:
8982
9253
  case GGML_UNARY_OP_TANH:
@@ -8987,6 +9258,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8987
9258
  return false;
8988
9259
  }
8989
9260
  break;
9261
+ case GGML_OP_GLU:
9262
+ switch (ggml_get_glu_op(node)) {
9263
+ case GGML_GLU_OP_GEGLU:
9264
+ case GGML_GLU_OP_REGLU:
9265
+ case GGML_GLU_OP_SWIGLU:
9266
+ case GGML_GLU_OP_GEGLU_ERF:
9267
+ case GGML_GLU_OP_GEGLU_QUICK:
9268
+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
9269
+ break;
9270
+ default:
9271
+ return false;
9272
+ }
9273
+ break;
8990
9274
  case GGML_OP_DIAG_MASK_INF:
8991
9275
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
8992
9276
 
@@ -9108,12 +9392,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9108
9392
 
9109
9393
  ctx->compute_ctx.reset();
9110
9394
 
9111
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
9395
+ bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
9112
9396
  if (!ok) {
9113
9397
  if (node->op == GGML_OP_UNARY) {
9114
9398
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
9115
- }
9116
- else {
9399
+ } else if (node->op == GGML_OP_GLU) {
9400
+ std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
9401
+ } else {
9117
9402
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
9118
9403
  }
9119
9404
  }
@@ -9122,7 +9407,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
9122
9407
  return true;
9123
9408
  }
9124
9409
 
9125
- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9410
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
9411
+ GGML_UNUSED(cgraph);
9126
9412
  ggml_backend_buffer * buf = nullptr;
9127
9413
 
9128
9414
  switch (tensor->op) {
@@ -9140,7 +9426,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9140
9426
  case GGML_OP_COS:
9141
9427
  case GGML_OP_CLAMP:
9142
9428
  case GGML_OP_PAD:
9429
+ case GGML_OP_ROLL:
9143
9430
  case GGML_OP_CPY:
9431
+ case GGML_OP_SET_ROWS:
9144
9432
  case GGML_OP_CONT:
9145
9433
  case GGML_OP_DUP:
9146
9434
  case GGML_OP_SILU_BACK:
@@ -9182,6 +9470,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9182
9470
  switch (ggml_get_unary_op(tensor)) {
9183
9471
  case GGML_UNARY_OP_SILU:
9184
9472
  case GGML_UNARY_OP_GELU:
9473
+ case GGML_UNARY_OP_GELU_ERF:
9185
9474
  case GGML_UNARY_OP_GELU_QUICK:
9186
9475
  case GGML_UNARY_OP_RELU:
9187
9476
  case GGML_UNARY_OP_TANH:
@@ -9192,6 +9481,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9192
9481
  return false;
9193
9482
  }
9194
9483
  break;
9484
+ case GGML_OP_GLU:
9485
+ switch (ggml_get_glu_op(tensor)) {
9486
+ case GGML_GLU_OP_GEGLU:
9487
+ case GGML_GLU_OP_REGLU:
9488
+ case GGML_GLU_OP_SWIGLU:
9489
+ case GGML_GLU_OP_GEGLU_ERF:
9490
+ case GGML_GLU_OP_GEGLU_QUICK:
9491
+ buf = tensor->buffer;
9492
+ break;
9493
+ default:
9494
+ return false;
9495
+ }
9496
+ break;
9195
9497
  case GGML_OP_MUL_MAT:
9196
9498
  case GGML_OP_MUL_MAT_ID:
9197
9499
  case GGML_OP_FLASH_ATTN_EXT:
@@ -9218,7 +9520,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9218
9520
  // Only run if ctx hasn't been submitted yet
9219
9521
  if (!subctx->seqs.empty()) {
9220
9522
  #ifdef GGML_VULKAN_CHECK_RESULTS
9221
- ggml_vk_check_results_0(tensor);
9523
+ ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
9222
9524
  use_fence = true;
9223
9525
  #endif
9224
9526
 
@@ -9238,7 +9540,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9238
9540
  ggml_vk_wait_for_fence(ctx);
9239
9541
  }
9240
9542
  #ifdef GGML_VULKAN_CHECK_RESULTS
9241
- ggml_vk_check_results_1(tensor);
9543
+ ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
9242
9544
  #endif
9243
9545
  }
9244
9546
 
@@ -9685,6 +9987,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
9685
9987
  return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
9686
9988
  }
9687
9989
 
9990
+ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
9991
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
9992
+ return false;
9993
+ }
9994
+
9995
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
9996
+ // additional constraints specific to this fusion
9997
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
9998
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
9999
+
10000
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
10001
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
10002
+ // rms_norm only supports f32
10003
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
10004
+ mul->src[1]->type != GGML_TYPE_F32 ||
10005
+ mul->type != GGML_TYPE_F32) {
10006
+ return false;
10007
+ }
10008
+ // if rms_norm is the B operand, then we don't handle broadcast
10009
+ if (rms_norm == mul->src[1] &&
10010
+ mul->src[0]->ne[1] != rms_norm->ne[1]) {
10011
+ return false;
10012
+ }
10013
+ // rms_norm shader assumes contiguous rows
10014
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
10015
+ return false;
10016
+ }
10017
+ }
10018
+ return true;
10019
+ }
10020
+
9688
10021
  static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9689
10022
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
9690
10023
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9698,10 +10031,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9698
10031
 
9699
10032
  uint64_t total_mat_mul_bytes = 0;
9700
10033
  for (int i = 0; i < cgraph->n_nodes; i++) {
9701
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
10034
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10035
+ ctx->num_additional_fused_ops = 1;
10036
+ }
10037
+ ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
9702
10038
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
9703
10039
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9704
10040
  }
10041
+ i += ctx->num_additional_fused_ops;
10042
+ ctx->num_additional_fused_ops = 0;
9705
10043
  }
9706
10044
  if (ctx->device->need_compiles) {
9707
10045
  ggml_vk_load_shaders(ctx->device);
@@ -9763,14 +10101,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9763
10101
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9764
10102
  }
9765
10103
 
10104
+ if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
10105
+ ctx->num_additional_fused_ops = 1;
10106
+ }
10107
+
9766
10108
  // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
9767
10109
  bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
9768
10110
  bool submit = (submitted_nodes >= nodes_per_submit) ||
9769
10111
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9770
- (i == last_node) ||
10112
+ (i + ctx->num_additional_fused_ops == last_node) ||
9771
10113
  (almost_ready && !ctx->almost_ready_fence_pending);
9772
10114
 
9773
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
10115
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
9774
10116
 
9775
10117
  if (vk_perf_logger_enabled) {
9776
10118
  if (ctx->compute_ctx.expired()) {
@@ -9780,7 +10122,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9780
10122
  } else {
9781
10123
  compute_ctx = ctx->compute_ctx.lock();
9782
10124
  }
9783
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
10125
+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
10126
+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
10127
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
10128
+ }
9784
10129
  }
9785
10130
 
9786
10131
  if (enqueued) {
@@ -9802,6 +10147,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9802
10147
  }
9803
10148
  submit_count++;
9804
10149
  }
10150
+ i += ctx->num_additional_fused_ops;
10151
+ ctx->num_additional_fused_ops = 0;
9805
10152
  }
9806
10153
 
9807
10154
  if (vk_perf_logger_enabled) {
@@ -9963,6 +10310,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9963
10310
  case GGML_OP_UNARY:
9964
10311
  switch (ggml_get_unary_op(op)) {
9965
10312
  case GGML_UNARY_OP_GELU:
10313
+ case GGML_UNARY_OP_GELU_ERF:
9966
10314
  case GGML_UNARY_OP_GELU_QUICK:
9967
10315
  case GGML_UNARY_OP_SILU:
9968
10316
  case GGML_UNARY_OP_RELU:
@@ -9976,15 +10324,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9976
10324
  return false;
9977
10325
  }
9978
10326
  break;
10327
+ case GGML_OP_GLU:
10328
+ switch (ggml_get_glu_op(op)) {
10329
+ case GGML_GLU_OP_GEGLU:
10330
+ case GGML_GLU_OP_REGLU:
10331
+ case GGML_GLU_OP_SWIGLU:
10332
+ case GGML_GLU_OP_GEGLU_ERF:
10333
+ case GGML_GLU_OP_GEGLU_QUICK:
10334
+ return ggml_is_contiguous(op->src[0]) &&
10335
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10336
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
10337
+ (op->src[0]->type == op->type);
10338
+ default:
10339
+ return false;
10340
+ }
10341
+ break;
9979
10342
  case GGML_OP_MUL_MAT:
9980
10343
  case GGML_OP_MUL_MAT_ID:
9981
10344
  {
9982
10345
  ggml_type src0_type = op->src[0]->type;
9983
10346
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
9984
10347
  const vk_device& device = ggml_vk_get_device(ctx->device);
9985
- if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
9986
- // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
9987
- return false;
10348
+ if (op->op == GGML_OP_MUL_MAT_ID) {
10349
+ if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
10350
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10351
+ return false;
10352
+ }
10353
+ // Check against size of shared memory variable
10354
+ if (op->src[2]->ne[0] > 4096) {
10355
+ return false;
10356
+ }
9988
10357
  }
9989
10358
  switch (src0_type) {
9990
10359
  case GGML_TYPE_F32:
@@ -10042,19 +10411,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10042
10411
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
10043
10412
  auto device = ggml_vk_get_device(ctx->device);
10044
10413
  bool coopmat2 = device->coopmat2;
10045
- switch (op->src[0]->ne[0]) {
10046
- case 64:
10047
- case 80:
10048
- case 96:
10049
- case 112:
10050
- case 128:
10051
- case 256:
10052
- break;
10053
- default:
10054
- return false;
10055
- }
10056
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
10057
- // different head sizes of K and V are not supported yet
10414
+ FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
10415
+ if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
10058
10416
  return false;
10059
10417
  }
10060
10418
  if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10134,6 +10492,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10134
10492
  return false;
10135
10493
  }
10136
10494
  } break;
10495
+ case GGML_OP_SET_ROWS:
10496
+ {
10497
+ switch (op->type) {
10498
+ case GGML_TYPE_F32:
10499
+ case GGML_TYPE_F16:
10500
+ case GGML_TYPE_BF16:
10501
+ case GGML_TYPE_Q4_0:
10502
+ case GGML_TYPE_Q4_1:
10503
+ case GGML_TYPE_Q5_0:
10504
+ case GGML_TYPE_Q5_1:
10505
+ case GGML_TYPE_Q8_0:
10506
+ case GGML_TYPE_IQ4_NL:
10507
+ return true;
10508
+ default:
10509
+ return false;
10510
+ }
10511
+ } break;
10137
10512
  case GGML_OP_CONT:
10138
10513
  case GGML_OP_CPY:
10139
10514
  case GGML_OP_DUP:
@@ -10218,11 +10593,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10218
10593
  case GGML_OP_CLAMP:
10219
10594
  return op->src[0]->type == GGML_TYPE_F32;
10220
10595
  case GGML_OP_UPSCALE:
10221
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10222
10596
  case GGML_OP_ACC:
10223
10597
  case GGML_OP_CONCAT:
10224
10598
  case GGML_OP_SCALE:
10225
10599
  case GGML_OP_PAD:
10600
+ case GGML_OP_ROLL:
10226
10601
  case GGML_OP_DIAG_MASK_INF:
10227
10602
  case GGML_OP_SOFT_MAX:
10228
10603
  case GGML_OP_SOFT_MAX_BACK:
@@ -10513,11 +10888,21 @@ void * comp_result;
10513
10888
  size_t comp_size;
10514
10889
  size_t comp_nb[GGML_MAX_DIMS];
10515
10890
  size_t check_counter = 0;
10516
- static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10891
+ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
10892
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10517
10893
  if (tensor->op == GGML_OP_TRANSPOSE) {
10518
10894
  return;
10519
10895
  }
10520
10896
 
10897
+ bool fused_rms_norm_mul = false;
10898
+ int rms_norm_idx = -1;
10899
+ if (ctx->num_additional_fused_ops == 1 &&
10900
+ tensor->op == GGML_OP_RMS_NORM &&
10901
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
10902
+ fused_rms_norm_mul = true;
10903
+ tensor = cgraph->nodes[tensor_idx + 1];
10904
+ }
10905
+
10521
10906
  check_counter++;
10522
10907
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10523
10908
  return;
@@ -10545,6 +10930,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10545
10930
 
10546
10931
  for (int i = 0; i < 6; i++) {
10547
10932
  ggml_tensor * srci = tensor->src[i];
10933
+ if (fused_rms_norm_mul) {
10934
+ rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
10935
+ ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
10936
+ switch (i) {
10937
+ case 0: srci = rms_norm->src[0]; break;
10938
+ case 1: srci = tensor->src[1 - rms_norm_idx]; break;
10939
+ default: continue;
10940
+ }
10941
+ }
10548
10942
  if (srci == nullptr) {
10549
10943
  continue;
10550
10944
  }
@@ -10602,7 +10996,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10602
10996
  } else if (tensor->op == GGML_OP_SUB) {
10603
10997
  tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
10604
10998
  } else if (tensor->op == GGML_OP_MUL) {
10605
- tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
10999
+ if (fused_rms_norm_mul) {
11000
+ tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
11001
+ tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
11002
+ } else {
11003
+ tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
11004
+ }
10606
11005
  } else if (tensor->op == GGML_OP_DIV) {
10607
11006
  tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
10608
11007
  } else if (tensor->op == GGML_OP_CONCAT) {
@@ -10690,6 +11089,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10690
11089
  case GGML_UNARY_OP_GELU:
10691
11090
  tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
10692
11091
  break;
11092
+ case GGML_UNARY_OP_GELU_ERF:
11093
+ tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
11094
+ break;
10693
11095
  case GGML_UNARY_OP_GELU_QUICK:
10694
11096
  tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
10695
11097
  break;
@@ -10706,6 +11108,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10706
11108
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
10707
11109
  GGML_ABORT("fatal error");
10708
11110
  }
11111
+ } else if (tensor->op == GGML_OP_GLU) {
11112
+ if (src_clone[1] == nullptr) {
11113
+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
11114
+ } else {
11115
+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
11116
+ }
10709
11117
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
10710
11118
  if (src1 == nullptr) {
10711
11119
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@@ -10713,6 +11121,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10713
11121
  } else {
10714
11122
  tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
10715
11123
  }
11124
+ } else if (tensor->op == GGML_OP_SET_ROWS) {
11125
+ tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
10716
11126
  } else if (tensor->op == GGML_OP_CONT) {
10717
11127
  tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
10718
11128
  } else if (tensor->op == GGML_OP_RESHAPE) {
@@ -10784,10 +11194,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10784
11194
  GGML_ABORT("fatal error");
10785
11195
  }
10786
11196
 
10787
- ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
10788
- ggml_build_forward_expand(cgraph, tensor_clone);
11197
+ ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
11198
+ ggml_build_forward_expand(cgraph_cpu, tensor_clone);
10789
11199
 
10790
- ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
11200
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
10791
11201
 
10792
11202
  if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
10793
11203
  ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -10810,10 +11220,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10810
11220
  VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
10811
11221
  }
10812
11222
 
10813
- static void ggml_vk_check_results_1(ggml_tensor * tensor) {
11223
+ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
11224
+ ggml_tensor * tensor = cgraph->nodes[tensor_idx];
10814
11225
  if (tensor->op == GGML_OP_TRANSPOSE) {
10815
11226
  return;
10816
11227
  }
11228
+ bool fused_rms_norm_mul = false;
11229
+ if (ctx->num_additional_fused_ops == 1 &&
11230
+ tensor->op == GGML_OP_RMS_NORM &&
11231
+ cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
11232
+ fused_rms_norm_mul = true;
11233
+ tensor = cgraph->nodes[tensor_idx + 1];
11234
+ }
11235
+
10817
11236
  if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
10818
11237
  return;
10819
11238
  }