@novastera-oss/llamarn 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/build-info.cpp +2 -2
  14. package/cpp/llama.cpp/README.md +11 -3
  15. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  16. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  17. package/cpp/llama.cpp/common/arg.cpp +153 -113
  18. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  19. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  20. package/cpp/llama.cpp/common/chat.cpp +847 -699
  21. package/cpp/llama.cpp/common/chat.h +73 -6
  22. package/cpp/llama.cpp/common/common.cpp +50 -82
  23. package/cpp/llama.cpp/common/common.h +21 -17
  24. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  25. package/cpp/llama.cpp/common/json-partial.h +37 -0
  26. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  27. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  28. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  29. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  30. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  31. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  32. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  33. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  34. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  35. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  36. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  37. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  74. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  120. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  121. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  122. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  123. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  124. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  125. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  126. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  127. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  128. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  129. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  130. package/cpp/llama.cpp/include/llama.h +62 -125
  131. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  132. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  133. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  150. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  152. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  159. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  160. package/cpp/llama.cpp/models/templates/README.md +2 -0
  161. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  162. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  163. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  164. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  165. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  166. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  167. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  168. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  169. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  170. package/cpp/llama.cpp/src/llama-context.h +30 -0
  171. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  172. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  173. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  174. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  175. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  176. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  177. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  178. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  179. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  180. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  181. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  182. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  183. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  184. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  185. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  186. package/cpp/llama.cpp/src/llama-model.h +6 -1
  187. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  188. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  189. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  190. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  191. package/cpp/llama.cpp/src/llama.cpp +14 -0
  192. package/cpp/rn-completion.cpp +4 -2
  193. package/ios/include/chat.h +73 -6
  194. package/ios/include/common/minja/chat-template.hpp +9 -5
  195. package/ios/include/common/minja/minja.hpp +69 -36
  196. package/ios/include/common.h +21 -17
  197. package/ios/include/llama.h +62 -125
  198. package/ios/libs/llama.xcframework/Info.plist +19 -19
  199. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  200. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  201. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  205. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  206. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  227. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  228. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  229. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  240. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  241. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  242. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  246. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  247. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  253. package/package.json +1 -1
  254. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  255. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  256. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  257. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
@@ -1,6 +1,6 @@
 #include "ggml-vulkan.h"
 #include <vulkan/vulkan_core.h>
-#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
+#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
 #include <chrono>
 #include "ggml-cpu.h"
 #endif
@@ -184,9 +184,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
 #ifdef GGML_VULKAN_MEMORY_DEBUG
 class vk_memory_logger;
 #endif
-#ifdef GGML_VULKAN_PERF
 class vk_perf_logger;
-#endif
 static void ggml_vk_destroy_buffer(vk_buffer& buf);

 static constexpr uint32_t mul_mat_vec_max_cols = 8;
@@ -288,6 +286,9 @@ struct vk_device_struct {
     bool coopmat_acc_f32_support {};
     bool coopmat_acc_f16_support {};
     bool coopmat_bf16_support {};
+    bool coopmat_support_16x16x16_f16acc {};
+    bool coopmat_support_16x16x16_f32acc {};
+    bool coopmat1_fa_support {};
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
@@ -410,6 +411,13 @@ struct vk_device_struct {
     vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];

+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
@@ -432,9 +440,11 @@ struct vk_device_struct {
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
-#ifdef GGML_VULKAN_PERF
+
+    // for GGML_VK_PERF_LOGGER
     std::unique_ptr<vk_perf_logger> perf_logger;
-#endif
+    vk::QueryPool query_pool;
+    uint32_t num_queries;

     ~vk_device_struct() {
         VK_LOG_DEBUG("destroy device " << name);
@@ -818,8 +828,6 @@ private:
 #define VK_LOG_MEMORY(msg) ((void) 0)
 #endif // GGML_VULKAN_MEMORY_DEBUG

-#if defined(GGML_VULKAN_PERF)
-
 class vk_perf_logger {
 public:
     void print_timings() {
@@ -829,7 +837,7 @@ public:
             for (const auto& time : t.second) {
                 total += time;
             }
-            std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl;
+            std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl;
         }

         timings.clear();
@@ -858,7 +866,6 @@ public:
 private:
     std::map<std::string, std::vector<uint64_t>> timings;
 };
-#endif // GGML_VULKAN_PERF

 struct ggml_backend_vk_context {
     std::string name;
@@ -948,6 +955,8 @@ struct vk_instance_t {
 static bool vk_instance_initialized = false;
 static vk_instance_t vk_instance;

+static bool vk_perf_logger_enabled = false;
+
 #ifdef GGML_VULKAN_CHECK_RESULTS
 static size_t vk_skip_checks;
 static size_t vk_output_tensor;
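Note: the hunks above replace the compile-time GGML_VULKAN_PERF build option with a runtime switch (vk_perf_logger_enabled, plus a vk::QueryPool for GPU timestamps). This excerpt shows the flag and the "for GGML_VK_PERF_LOGGER" comment, but not the code that sets the flag. A minimal hypothetical sketch of such an environment-variable gate (the initialization site and exact semantics are assumptions, not taken from the diff):

    // Hypothetical sketch: gate the runtime perf logger on an env var,
    // consistent with the "for GGML_VK_PERF_LOGGER" comment above.
    #include <cstdio>
    #include <cstdlib>

    static bool vk_perf_logger_enabled = false;

    int main() {
        // Presence of the variable enables per-op GPU timing, e.g. GGML_VK_PERF_LOGGER=1.
        vk_perf_logger_enabled = std::getenv("GGML_VK_PERF_LOGGER") != nullptr;
        std::printf("perf logger %s\n", vk_perf_logger_enabled ? "on" : "off");
        return 0;
    }

A runtime flag like this means one binary can produce timings on demand, where the old GGML_VULKAN_PERF define required a rebuild.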
@@ -1588,19 +1597,36 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }

+enum FaCodePath {
+    FA_SCALAR,
+    FA_COOPMAT1,
+    FA_COOPMAT2,
+};
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
 static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;

-static uint32_t get_fa_num_small_rows(bool scalar) {
-    return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
+// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
+// 128 threads split into four subgroups, each subgroup does 1/4
+// of the Bc dimension.
+static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
+static constexpr uint32_t scalar_flash_attention_Bc = 64;
+static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+
+static uint32_t get_fa_num_small_rows(FaCodePath path) {
+    if (path == FA_COOPMAT2) {
+        return flash_attention_num_small_rows;
+    } else {
+        return scalar_flash_attention_num_small_rows;
+    }
 }

-static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);

-    if (scalar) {
+    if (path == FA_SCALAR) {
         if (small_rows) {
             return {scalar_flash_attention_num_small_rows, 64};
         } else {
@@ -1608,9 +1634,17 @@ static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t cl
         }
     }

+    if (path == FA_COOPMAT1) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
+        } else {
+            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+        }
+    }
+
     // small rows, large cols
     if (small_rows) {
-        return {get_fa_num_small_rows(scalar), 32};
+        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
     }

     // small cols to reduce register count
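Note: between them, the two hunks above fully specify the flash-attention {rows, cols} tiles for the FA_SCALAR small-rows case, both FA_COOPMAT1 cases, and the FA_COOPMAT2 small-rows case. A standalone restatement limited to exactly those branches (fa_rows_cols_sketch is a made-up name; branches this excerpt truncates return {0, 0} here):

    // Restates the tile selection visible in the diff above; nothing more.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };

    static constexpr uint32_t flash_attention_num_small_rows          = 32;
    static constexpr uint32_t scalar_flash_attention_num_small_rows   = 1;
    static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
    static constexpr uint32_t scalar_flash_attention_Bc               = 64;

    static std::array<uint32_t, 2> fa_rows_cols_sketch(FaCodePath path, bool small_rows) {
        if (path == FA_SCALAR && small_rows) {
            return {scalar_flash_attention_num_small_rows, 64}; // 1 row x 64 cols
        }
        if (path == FA_COOPMAT1) {
            // coopmat1 always tiles Bc = 64; 16 rows unless in small-rows mode.
            if (small_rows) {
                return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
            }
            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
        }
        if (path == FA_COOPMAT2 && small_rows) {
            // "small rows, large cols": get_fa_num_small_rows(FA_COOPMAT2) == 32
            return {flash_attention_num_small_rows, 32};
        }
        return {0, 0}; // remaining branches are not visible in this excerpt
    }

    int main() {
        const auto rc = fa_rows_cols_sketch(FA_COOPMAT1, /*small_rows=*/false);
        std::printf("coopmat1 large-row tile: Br=%u, Bc=%u\n", (unsigned) rc[0], (unsigned) rc[1]);
        return 0;
    }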
@@ -1907,17 +1941,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };

-    auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
     };

-    auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
         // For scalar, use 128 (arbitrary)
-        uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
+        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
+            ? scalar_flash_attention_workgroup_size
+            : ((small_rows && (D % 32) == 0) ? 256 : 128);
+        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);

         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
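Note: the fa_spec_constants change above makes the workgroup-size rule explicit: the scalar and coopmat1 paths always use the fixed 128-thread workgroup, while coopmat2 picks 256 only for small-row variants whose D is a multiple of 32. A compile-time restatement of exactly that rule (fa_wg_size is a made-up name):

    // The workgroup-size rule from the hunk above, checkable at compile time.
    #include <cstdint>

    enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };

    static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;

    constexpr uint32_t fa_wg_size(FaCodePath path, uint32_t D, bool small_rows) {
        return (path == FA_SCALAR || path == FA_COOPMAT1)
            ? scalar_flash_attention_workgroup_size
            : ((small_rows && (D % 32) == 0) ? 256u : 128u);
    }

    static_assert(fa_wg_size(FA_COOPMAT2, 128, true) == 256, "small rows, D % 32 == 0");
    static_assert(fa_wg_size(FA_COOPMAT2,  80, true) == 128, "matrix granularity for 256 is 32, so D==80 can't use it");
    static_assert(fa_wg_size(FA_COOPMAT1, 128, true) == 128, "coopmat1 is fixed at 128 threads");

    int main() { return 0; }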
@@ -1929,36 +1965,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
     };

-#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
-
-#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
-        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
-
-    CREATE_FA(GGML_TYPE_F16, f16, true, )
-    CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
-    CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+
+#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+
+    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    if (device->coopmat1_fa_support) {
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
+    }
+#endif
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {
-        CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
     }
 #endif
 #undef CREATE_FA2
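Note: reading the CREATE_FA2 body above, each pipeline array is filled in all eight combinations of its last three indices, [type][acc][small_rows][aligned]: acc selects the f16- vs f32-accumulator shader, small_rows the small-row variant, and aligned the variant whose required alignment is fa_rows_cols(...)[1]. A tiny sketch of that mapping (FaVariant is a made-up helper, not part of the diff):

    // Makes the [type][acc][small_rows][aligned] indexing explicit.
    #include <cstdio>

    struct FaVariant {
        bool f32acc;     // index 1: 0 = f16 accumulator, 1 = f32 accumulator
        bool small_rows; // index 2: small-row variant, tuned for tiny N
        bool aligned;    // index 3: aligned variant (align = fa_rows_cols(...)[1])
    };

    int main() {
        const FaVariant v{true, false, true}; // f32acc, large rows, aligned
        std::printf("pipeline slot: [type][%d][%d][%d]\n", v.f32acc, v.small_rows, v.aligned);
        return 0;
    }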
@@ -1987,25 +2030,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
     }
 #endif
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
-    CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+    CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)

     CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2041,17 +2084,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         if (device->mul_mat ## ID ## _l[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _m[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _s[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _l[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
         if (device->mul_mat ## ID ## _m[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
         if (device->mul_mat ## ID ## _s[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \

     // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2073,47 +2116,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif

     if (device->coopmat_acc_f16_support) {
-        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
-        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     } else {
-        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
-        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2152
+ CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2153
+ CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2154
+ CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2155
+ CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2156
+ CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2157
+ CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2158
+ CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2159
+ CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2117
2160
  }
2118
2161
 
2119
2162
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2188,13 +2231,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

- #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
- if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) { \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ } \
+ if (device->mul_mat ## ID ## _m[TYPE]) { \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ } \
+ if (device->mul_mat ## ID ## _s[TYPE]) { \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ } \

  // Create 2 variants, {f16,f32} accumulator
  #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
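
The reworked CREATE_MMQ above registers an f16-accumulator and an f32-accumulator pipeline behind a single name, deriving both the human-readable pipeline names and the generated SPIR-V symbol names in the preprocessor: #NAMELC stringizes the shader name and NAMELC ## _f16acc_len pastes the suffixed token. A minimal standalone sketch of that token-pasting technique follows; register_pipeline and the *_data/*_len arrays are hypothetical stand-ins for ggml_vk_create_pipeline and the embedded shader blobs.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-ins for the generated SPIR-V arrays.
    static const uint8_t matmul_q4_0_q8_1_data[]        = {0};
    static const size_t  matmul_q4_0_q8_1_len           = sizeof(matmul_q4_0_q8_1_data);
    static const uint8_t matmul_q4_0_q8_1_f16acc_data[] = {0};
    static const size_t  matmul_q4_0_q8_1_f16acc_len    = sizeof(matmul_q4_0_q8_1_f16acc_data);

    static void register_pipeline(const char * name, const uint8_t * spv, size_t len) {
        std::printf("pipeline %s: %zu bytes of SPIR-V\n", name, len);
        (void) spv;
    }

    // #NAMELC stringizes the token, NAMELC ## _f16acc_data pastes it,
    // mirroring how CREATE_MMQ emits both accumulator variants at once.
    #define REGISTER_MM2(NAMELC) \
        register_pipeline(#NAMELC "_f16acc", NAMELC ## _f16acc_data, NAMELC ## _f16acc_len); \
        register_pipeline(#NAMELC,           NAMELC ## _data,        NAMELC ## _len);

    int main() {
        REGISTER_MM2(matmul_q4_0_q8_1)
        return 0;
    }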
@@ -2208,34 +2257,34 @@ static void ggml_vk_load_shaders(vk_device& device) {

  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );

- CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-
- CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
- CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
  }
  #endif

@@ -2284,13 +2333,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

- #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  if (device->mul_mat ## ID ## _l[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
  if (device->mul_mat ## ID ## _m[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
  if (device->mul_mat ## ID ## _s[TYPE]) \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \

  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2322,11 +2371,11 @@ static void ggml_vk_load_shaders(vk_device& device) {

  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
  }
  #endif

@@ -2707,9 +2756,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  #ifdef GGML_VULKAN_MEMORY_DEBUG
  device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
  #endif
- #ifdef GGML_VULKAN_PERF
- device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
- #endif
+ if (vk_perf_logger_enabled) {
+ device->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
+ }

  size_t dev_num = vk_instance.device_indices[idx];
 
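With this hunk the perf logger is no longer compiled in or out via GGML_VULKAN_PERF; it is allocated whenever vk_perf_logger_enabled is set, which the instance-init hunk further down derives from the GGML_VK_PERF_LOGGER environment variable. A sketch of the same runtime-toggle pattern, with a hypothetical PerfLogger type standing in for vk_perf_logger:

    #include <cstdlib>
    #include <memory>

    struct PerfLogger { /* hypothetical stand-in for vk_perf_logger */ };

    static bool perf_logger_enabled = false;
    static std::unique_ptr<PerfLogger> perf_logger;

    int main() {
        // One getenv at startup replaces the old #ifdef: profiling can be
        // switched on per run instead of per build.
        perf_logger_enabled = std::getenv("GGML_VK_PERF_LOGGER") != nullptr;
        if (perf_logger_enabled) {
            perf_logger = std::make_unique<PerfLogger>();
        }
        return 0;
    }
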
@@ -2754,23 +2803,29 @@ static vk_device ggml_vk_get_device(size_t idx) {
  pipeline_robustness = true;
  } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
  device->subgroup_size_control = true;
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT")) {
  device->coopmat_support = true;
  device->coopmat_m = 0;
  device->coopmat_n = 0;
  device->coopmat_k = 0;
+ #endif
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
  coopmat2_support = true;
+ #endif
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
  } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
  device->integer_dot_product = true;
  #endif
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
  } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_BFLOAT16")) {
  bfloat16_support = true;
+ #endif
  }
  }

@@ -3009,6 +3064,11 @@ static vk_device ggml_vk_get_device(size_t idx) {

  #if defined(VK_KHR_cooperative_matrix)
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
+
+ // coopmat1 fa shader currently assumes 32 invocations per subgroup
+ device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
+ device->subgroup_size_control && device->subgroup_min_size <= 32 &&
+ device->subgroup_max_size >= 32;
  #endif

  if (coopmat2_support) {
@@ -3143,6 +3203,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  // Only enable if shape is identical
  device->coopmat_acc_f32_support = true;
  }
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+ device->coopmat_support_16x16x16_f32acc = true;
+ }
  } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
  (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
  // coopmat sizes not set yet
@@ -3155,6 +3218,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  // Only enable if shape is identical
  device->coopmat_acc_f16_support = true;
  }
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+ device->coopmat_support_16x16x16_f16acc = true;
+ }
  }
  } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
  (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
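
These two hunks record whether the device reports a 16x16x16 cooperative-matrix shape for f32 and f16 accumulators; the coopmat1 flash-attention path added later checks exactly these flags. A simplified sketch of the scan, with a small struct standing in for vk::CooperativeMatrixPropertiesKHR:

    #include <vector>

    // Simplified stand-in for vk::CooperativeMatrixPropertiesKHR.
    struct CoopmatProps { unsigned MSize, NSize, KSize; bool f16_acc; };

    struct Caps { bool support_16x16x16_f16acc = false; bool support_16x16x16_f32acc = false; };

    static Caps scan_coopmat_shapes(const std::vector<CoopmatProps> & props) {
        Caps caps;
        for (const CoopmatProps & p : props) {
            if (p.MSize == 16 && p.NSize == 16 && p.KSize == 16) {
                // Record the shape per accumulator type, as the hunks above do.
                (p.f16_acc ? caps.support_16x16x16_f16acc
                           : caps.support_16x16x16_f32acc) = true;
            }
        }
        return caps;
    }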
@@ -3480,6 +3546,8 @@ static void ggml_vk_instance_init() {
  vk_instance.instance = vk::createInstance(instance_create_info);
  vk_instance_initialized = true;

+ vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
+
  size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();

  // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -3656,7 +3724,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
  }

  static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
- VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
  return ctx->device->pipeline_matmul_f32;
  }
@@ -3684,7 +3752,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte

  // MMQ
  if (src1_type == GGML_TYPE_Q8_1) {
- vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+ vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;

  if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
  return nullptr;
@@ -3724,9 +3792,12 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte

  if (ctx->device->coopmat2) {
  assert(src1_type == GGML_TYPE_F16);
- return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
+ return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
  }
- return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
+ if (ctx->device->coopmat_support) {
+ return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
+ }
+ return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
  }

  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
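
ggml_vk_get_mul_mat_mat_pipeline now threads the requested ggml_prec through every branch instead of defaulting to the f16 accumulator: f16acc is only returned when the device has fp16, the op asks for default precision, and, on the coopmat path, f16 accumulation is actually supported. A condensed sketch of that decision order, with hypothetical Device/Pipelines types in place of the real vk_device and pipeline structs:

    enum class Prec { Default, F32 };

    struct Device {
        bool fp16 = true;
        bool coopmat_support = false;
        bool coopmat_acc_f16_support = false;
    };

    struct Pipelines { const char * f16acc = "f16acc"; const char * f32acc = "f32acc"; };

    // Mirrors the fallback order above: any doubt about f16 accumulation
    // (no fp16, explicit F32 precision, missing coopmat f16 accumulator)
    // routes the caller to the f32-accumulator pipeline.
    static const char * pick_pipeline(const Device & d, const Pipelines & p, Prec prec) {
        if (d.coopmat_support) {
            return (d.fp16 && d.coopmat_acc_f16_support && prec == Prec::Default)
                       ? p.f16acc : p.f32acc;
        }
        return (d.fp16 && prec == Prec::Default) ? p.f16acc : p.f32acc;
    }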
@@ -4449,6 +4520,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
  return aligned ? mmp->a_m : mmp->m;
  }
  return aligned ? mmp->a_l : mmp->l;
+
+ GGML_UNUSED(src1_type);
  }

  static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@@ -4604,6 +4677,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
  }
  }

+ if (src->type == to) {
+ // Copy two or four bytes at a time, depending on block size.
+ // For quantized types, we scale by block size/type size. But
+ // this path is also used for bf16->bf16 for example, where the
+ // type size must be exactly 2 or 4.
+ GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
+ if ((ggml_type_size(src->type) % 4) == 0) {
+ return ctx->device->pipeline_contig_cpy_f32_f32;
+ } else {
+ return ctx->device->pipeline_contig_cpy_f16_f16;
+ }
+ }
+
  std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
  GGML_ABORT("fatal error");
  }
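
The new same-type branch turns a contiguous copy into a plain memcpy performed by the existing f32->f32 or f16->f16 contiguous-copy shaders, so the only decision left is the word width. A sketch of that choice, with type_size standing in for ggml_type_size(src->type):

    #include <cstddef>

    enum class CopyShader { FourByteWords, TwoByteWords };

    // Same-type contiguous copy: use the widest shader the size allows.
    static CopyShader pick_copy_shader(std::size_t type_size) {
        return (type_size % 4) == 0 ? CopyShader::FourByteWords
                                    : CopyShader::TwoByteWords;
    }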
@@ -5688,6 +5774,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
  }
  }

+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+ // Needs to be kept up to date on shader changes
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+ const uint32_t Br = scalar_flash_attention_num_large_rows;
+ const uint32_t Bc = scalar_flash_attention_Bc;
+
+ const uint32_t acctype = f32acc ? 4 : 2;
+ const uint32_t f16vec4 = 8;
+
+ const uint32_t tmpsh = wg_size * sizeof(float);
+ const uint32_t tmpshv4 = wg_size * 4 * acctype;
+
+ const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+
+ const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+ const uint32_t sfsh = Bc * sfshstride * acctype;
+
+ const uint32_t kshstride = D / 4 + 2;
+ const uint32_t ksh = Bc * kshstride * f16vec4;
+
+ const uint32_t slope = Br * sizeof(float);
+
+ const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+
+ return supported;
+ }
+
  static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
  VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
  std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
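
ggml_vk_flash_attn_coopmat_shmem_support simply adds up every shared-memory array the coopmat1 shader declares and compares the sum against maxComputeSharedMemorySize. A worked version of the same arithmetic is sketched below; wg_size=128, Br=8, Bc=64 and the 48 KiB limit are illustrative assumptions, not necessarily the shader's real constants:

    #include <cstdint>
    #include <cstdio>

    static constexpr uint32_t wg_size = 128, Br = 8, Bc = 64;  // assumed values

    static uint32_t fa_shmem_bytes(uint32_t D, bool f32acc) {
        const uint32_t acctype    = f32acc ? 4 : 2;      // bytes per accumulator scalar
        const uint32_t f16vec4    = 8;                   // bytes per f16vec4
        const uint32_t tmpsh      = wg_size * 4;         // sizeof(float)
        const uint32_t tmpshv4    = wg_size * 4 * acctype;
        const uint32_t Qf         = Br * (D / 4 + 2) * f16vec4;
        const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
        const uint32_t sfsh       = Bc * sfshstride * acctype;
        const uint32_t kshstride  = D / 4 + 2;
        const uint32_t ksh        = Bc * kshstride * f16vec4;
        const uint32_t slope      = Br * 4;              // sizeof(float)
        return tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
    }

    int main() {
        // e.g. head size 128 with f32 accumulation against a 48 KiB limit
        const uint32_t total = fa_shmem_bytes(128, /*f32acc=*/true);
        std::printf("%u bytes, fits in 49152: %d\n", total, total <= 49152u);
        return 0;
    }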
@@ -5738,7 +5854,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  assert(q->type == GGML_TYPE_F32);
  assert(k->type == v->type);

- bool scalar = !ctx->device->coopmat2;
+ FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
+ ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
+
+ if (path == FA_COOPMAT1) {
+ const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
+ (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
+
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+
+ if (!coopmat_shape_supported || !coopmat_shmem_supported) {
+ path = FA_SCALAR;
+ }
+ }

  uint32_t gqa_ratio = 1;
  uint32_t qk_ratio = neq2 / nek2;
@@ -5746,9 +5874,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  uint32_t workgroups_y = (uint32_t)neq2;
  uint32_t workgroups_z = (uint32_t)neq3;

- // For scalar FA, we can use the "large" size to accommodate qga.
- // For coopmat FA, we always use the small size (which is still pretty large for gqa).
- const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
+ // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
+ // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
+ uint32_t max_gqa;
+ switch (path) {
+ case FA_SCALAR:
+ case FA_COOPMAT1:
+ // We may switch from coopmat1 to scalar, so use the scalar limit for both
+ max_gqa = scalar_flash_attention_num_large_rows;
+ break;
+ case FA_COOPMAT2:
+ max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
+ break;
+ default:
+ GGML_ASSERT(0);
+ }

  if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
  qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
@@ -5761,11 +5901,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  }
  }
  vk_pipeline *pipelines;
- // XXX TODO other backends may be changing accumulator precision to default to f32 soon
- bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
- bool small_rows = N <= get_fa_num_small_rows(scalar);
+ bool small_rows = N <= get_fa_num_small_rows(path);
+
+ // coopmat1 does not actually support "small rows" (it needs 16 rows).
+ // So use scalar instead.
+ if (small_rows && path == FA_COOPMAT1) {
+ path = FA_SCALAR;
+ }

- if (scalar) {
+ // scalar is faster than coopmat2 when N==1
+ if (N == 1 && path == FA_COOPMAT2) {
+ path = FA_SCALAR;
+ }
+
+ bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+ switch (path) {
+ case FA_SCALAR:
  switch (D) {
  case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
  case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
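
Across the last few hunks the old scalar/coopmat2 boolean becomes a three-way FaCodePath that is progressively demoted: coopmat1 drops to scalar when the 16x16x16 shape, the shared-memory budget, or the row count does not fit, and coopmat2 drops to scalar for single-row (N == 1) dispatches, where scalar is faster. A condensed sketch of the ladder, with the capability booleans standing in for the real device queries:

    enum class FaPath { Scalar, Coopmat1, Coopmat2 };

    struct FaCaps {
        bool coopmat2 = false, coopmat1_fa = false;
        bool shape_ok = true, shmem_ok = true;  // stand-ins for the real checks
    };

    static FaPath pick_fa_path(const FaCaps & c, unsigned N, bool small_rows) {
        FaPath path = c.coopmat2    ? FaPath::Coopmat2
                    : c.coopmat1_fa ? FaPath::Coopmat1
                                    : FaPath::Scalar;
        if (path == FaPath::Coopmat1 && (!c.shape_ok || !c.shmem_ok)) {
            path = FaPath::Scalar;  // unsupported shape or shared-memory budget
        }
        if (path == FaPath::Coopmat1 && small_rows) {
            path = FaPath::Scalar;  // coopmat1 needs 16 rows
        }
        if (path == FaPath::Coopmat2 && N == 1) {
            path = FaPath::Scalar;  // scalar wins for single-row dispatches
        }
        return path;
    }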
@@ -5777,7 +5929,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  GGML_ASSERT(!"unsupported D value");
  return;
  }
- } else {
+ break;
+ case FA_COOPMAT1:
+ switch (D) {
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
+ default:
+ GGML_ASSERT(!"unsupported D value");
+ return;
+ }
+ break;
+ case FA_COOPMAT2:
  switch (D) {
  case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
  case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
@@ -5789,6 +5955,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
  GGML_ASSERT(!"unsupported D value");
  return;
  }
+ break;
+ default:
+ GGML_ASSERT(0);
  }
  assert(pipelines);

@@ -6284,6 +6453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
  case GGML_OP_ROPE:
  case GGML_OP_RMS_NORM:
  case GGML_OP_CONV_2D_DW:
+ case GGML_OP_IM2COL:
  return true;
  default:
  return false;
@@ -6582,7 +6752,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  case GGML_OP_UNARY:
  case GGML_OP_CONV_2D_DW:
  {
- const uint32_t ne = ggml_nelements(dst);
+ uint32_t ne = ggml_nelements(dst);
+ if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+ // Convert from number of logical elements to 2- or 4-byte units.
+ ne /= ggml_blck_size(src0->type);
+ if ((ggml_type_size(src0->type) % 4) == 0) {
+ ne *= ggml_type_size(src0->type) / 4;
+ } else {
+ ne *= ggml_type_size(src0->type) / 2;
+ }
+ }
  if (ne > 262144) {
  elements = { 512, 512, CEIL_DIV(ne, 262144) };
  } else if (ne > 512) {
@@ -7132,8 +7311,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
  const uint32_t src0_type_size = ggml_type_size(src0->type);
  const uint32_t dst_type_size = ggml_type_size(dst->type);

+ uint32_t ne = (uint32_t)ggml_nelements(src0);
+ if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+ // Convert from number of logical elements to 2- or 4-byte units.
+ ne /= ggml_blck_size(src0->type);
+ if ((ggml_type_size(src0->type) % 4) == 0) {
+ ne *= ggml_type_size(src0->type) / 4;
+ } else {
+ ne *= ggml_type_size(src0->type) / 2;
+ }
+ }
+
  ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
- (uint32_t)ggml_nelements(src0),
+ ne,
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
  0,
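
Both CPY hunks perform the same conversion: a quantized->quantized copy is dispatched over raw 2- or 4-byte words, so the logical element count is divided by the block size and scaled by the block's size in words. For example, assuming a 256-element block of 144 bytes (the Q4_K layout), 1024 logical elements become 1024 / 256 * 144 / 4 = 144 four-byte words. The arithmetic in isolation:

    #include <cstdint>

    // blck and tsize stand in for ggml_blck_size/ggml_type_size of the type.
    static uint32_t logical_to_word_units(uint32_t ne, uint32_t blck, uint32_t tsize) {
        ne /= blck;                       // logical elements -> blocks
        if ((tsize % 4) == 0) {
            return ne * (tsize / 4);      // blocks -> 4-byte words
        }
        return ne * (tsize / 2);          // blocks -> 2-byte words
    }

    int main() {
        // 1024 elements of a 256-element / 144-byte block type -> 144 words
        return logical_to_word_units(1024, 256, 144) == 144 ? 0 : 1;
    }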
@@ -8696,7 +8886,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  ctx->tensor_ctxs[node_idx] = compute_ctx;

- #if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF)
+ #if defined(GGML_VULKAN_CHECK_RESULTS)
  // Force context reset on each node so that each tensor ends up in its own context
  // and can be run and compared to its CPU equivalent separately
  last_node = true;
@@ -9115,8 +9305,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
  try {
  ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
  } catch (vk::SystemError& e) {
- std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
- std::cerr << "ggml_vulkan: " << e.what() << std::endl;
+ GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
  // fallback to cpu buffer
  return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
  }
@@ -9317,6 +9506,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  bool first_node_in_batch = true; // true if next node will be first node in a batch
  int submit_node_idx = 0; // index to first node in a batch

+ vk_context compute_ctx;
+ if (vk_perf_logger_enabled) {
+ // allocate/resize the query pool
+ if (ctx->device->num_queries < cgraph->n_nodes + 1) {
+ if (ctx->device->query_pool) {
+ ctx->device->device.destroyQueryPool(ctx->device->query_pool);
+ }
+ VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO };
+ query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP;
+ query_create_info.queryCount = cgraph->n_nodes + 100;
+ ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info);
+ ctx->device->num_queries = query_create_info.queryCount;
+ }
+
+ ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1);
+
+ GGML_ASSERT(ctx->compute_ctx.expired());
+ compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
+ }
+
  // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
  // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
  // (and scaled down based on model size, so smaller models submit earlier).
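
The enabled perf logger brackets the whole graph with Vulkan timestamp queries: a pool sized for at least n_nodes+1 queries is (re)created on demand and reset, a timestamp is written before the first node, and one more after each node in the loop below. A minimal standalone sketch of that Vulkan-Hpp pattern (instance, device, and command-buffer plumbing elided):

    #include <vulkan/vulkan.hpp>
    #include <vector>

    // Record one timestamp per graph node; read everything back after the fence.
    void profile_nodes(vk::Device device, vk::CommandBuffer cmd, uint32_t n_nodes) {
        vk::QueryPoolCreateInfo info({}, vk::QueryType::eTimestamp, n_nodes + 1);
        vk::QueryPool pool = device.createQueryPool(info);
        device.resetQueryPool(pool, 0, n_nodes + 1);  // host-side reset (Vulkan 1.2)

        cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, 0);
        for (uint32_t i = 0; i < n_nodes; i++) {
            // ... record node i's dispatches into cmd ...
            cmd.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, pool, i + 1);
        }

        // After submit + fence wait:
        std::vector<uint64_t> ticks(n_nodes + 1);
        (void) device.getQueryPoolResults(pool, 0, n_nodes + 1,
                                          ticks.size() * sizeof(uint64_t), ticks.data(),
                                          sizeof(uint64_t),
                                          vk::QueryResultFlagBits::e64 |
                                          vk::QueryResultFlagBits::eWait);
    }

The diff sizes the pool at n_nodes + 100 so that small graph-size fluctuations do not force a destroy/create cycle on every invocation.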
@@ -9344,6 +9556,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

  bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);

+ if (vk_perf_logger_enabled) {
+ if (ctx->compute_ctx.expired()) {
+ compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ctx->compute_ctx = compute_ctx;
+ ggml_vk_ctx_begin(ctx->device, compute_ctx);
+ } else {
+ compute_ctx = ctx->compute_ctx.lock();
+ }
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+ }
+
  if (enqueued) {
  ++submitted_nodes;

@@ -9365,9 +9588,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  }
  }

- #ifdef GGML_VULKAN_PERF
- ctx->device->perf_logger->print_timings();
- #endif
+ if (vk_perf_logger_enabled) {
+ // End the command buffer and submit/wait
+ GGML_ASSERT(!ctx->compute_ctx.expired());
+ compute_ctx = ctx->compute_ctx.lock();
+ ggml_vk_ctx_end(compute_ctx);
+
+ ggml_vk_submit(compute_ctx, ctx->device->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
+ ctx->device->device.resetFences({ ctx->device->fence });
+
+ // Get the results and pass them to the logger
+ std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
+ ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (!ggml_vk_is_empty(cgraph->nodes[i])) {
+ ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod));
+ }
+ }
+
+ ctx->device->perf_logger->print_timings();
+ }

  ggml_vk_graph_cleanup(ctx);
 
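The query results are raw device ticks; multiplying each per-node delta by timestampPeriod (nanoseconds per tick, from the device limits) is what converts them into the wall-clock durations handed to log_timing above. The conversion in isolation:

    #include <cstdint>

    // period_ns corresponds to VkPhysicalDeviceLimits::timestampPeriod.
    static uint64_t ticks_to_ns(uint64_t t0, uint64_t t1, float period_ns) {
        return (uint64_t) ((t1 - t0) * period_ns);
    }
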
@@ -9718,6 +9959,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
  return true;
  }
+
+ // We can handle copying from a type to the same type if it's
+ // contiguous (memcpy). We use f16 or f32 shaders to do the copy,
+ // so the type/block size must be a multiple of 4.
+ if (src0_type == src1_type &&
+ ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
+ (ggml_type_size(src0_type) % 2) == 0) {
+ return true;
+ }
  return false;
  } break;
  case GGML_OP_REPEAT:
@@ -10123,7 +10373,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  } else if (tensor->op == GGML_OP_CONCAT) {
  tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
  } else if (tensor->op == GGML_OP_UPSCALE) {
- tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
  } else if (tensor->op == GGML_OP_SCALE) {
  const float * params = (const float *)tensor->op_params;
  tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
@@ -10412,7 +10662,8 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
  ggml_vk_print_graph_origin(tensor, done);
  GGML_ABORT("fatal error");
  }
- if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
+ const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
+ if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
  first_error[0] = i0;
  first_error[1] = i1;
  first_error[2] = i2;
@@ -10424,7 +10675,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
  // Special case, value is infinite, avoid NaN result in avg_err
  // NaN also appears in results, if both are nan error is 0
  if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
- avg_err += std::fabs(correct - result);
+ avg_err += std::fabs(correct - result) / denom;
  }
  counter++;
  }
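
The results checker moves from an absolute to a relative error metric: the deviation is divided by |correct| once it exceeds 1 (with the denominator clamped away from zero), so large-magnitude outputs no longer trip a fixed 0.1 threshold, and avg_err is compared against 0.5 on the same relative scale in the next hunk. The metric in isolation:

    #include <cmath>

    // Relative error with the same clamped denominator as the hunks above.
    static double rel_err(double correct, double result) {
        const double denom = std::fabs(correct) > 1.0
                                 ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8)
                                 : 1.0;
        return std::fabs(correct - result) / denom;
    }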
@@ -10459,7 +10710,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
  ggml_vk_print_graph_origin(tensor, done);
  }

- if (avg_err > 0.05 || std::isnan(avg_err)) {
+ if (avg_err > 0.5 || std::isnan(avg_err)) {
  std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
  std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
  if (src0 != nullptr) {