whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
11
11
 
12
12
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
13
13
  #define MMQ_ITER_K 256
14
+ #define MMQ_ITER_K_MXFP4_FP4 512
14
15
  #define MMQ_NWARPS 8
15
16
 
16
17
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
44
45
  };
45
46
  int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
46
47
  };
48
+
49
+ struct block_fp4_mmq {
50
+ uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
51
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
52
+ };
53
+
47
54
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
48
55
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
56
+ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
49
57
 
50
58
  static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
51
59
  switch (type_x) {
@@ -92,7 +100,7 @@ struct tile_x_sizes {
92
100
  };
93
101
 
94
102
  static int get_mmq_x_max_host(const int cc) {
95
- return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
103
+ return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
96
104
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
97
105
  #ifdef GGML_CUDA_FORCE_MMQ
98
106
  128 : 64;
@@ -102,7 +110,7 @@ static int get_mmq_x_max_host(const int cc) {
102
110
  }
103
111
 
104
112
  static constexpr __device__ int get_mmq_x_max_device() {
105
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
113
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
106
114
  return 128;
107
115
  #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
108
116
 
@@ -121,7 +129,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
121
129
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
122
130
 
123
131
  #endif // defined(GGML_USE_HIP)
124
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
132
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
125
133
  }
126
134
 
127
135
  static int get_mmq_y_host(const int cc) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
129
137
  ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
130
138
  }
131
139
 
140
+ static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
141
+ #if defined(BLACKWELL_MMA_AVAILABLE)
142
+ return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
143
+ #else
144
+ return MMQ_ITER_K;
145
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
146
+ }
147
+
132
148
  static constexpr __device__ int get_mmq_y_device() {
133
149
  #if defined(GGML_USE_HIP)
134
150
  #if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
191
207
  }
192
208
 
193
209
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
210
+ #define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
194
211
  #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
195
212
  #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
196
213
  #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
201
218
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
202
219
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
203
220
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
221
+ static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
222
+ static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
204
223
 
205
224
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
206
225
  switch (type) {
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
209
228
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
210
229
  case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
211
230
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
231
+ // tile sizes are the same for Q8_1 and FP4 for blackwell
212
232
  case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
213
233
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
214
234
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,10 +248,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
228
248
  }
229
249
 
230
250
  // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
231
- #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
251
+ #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
252
+ #define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
232
253
 
233
254
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
234
- if (amd_mfma_available(cc)) {
255
+ if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
235
256
  return mmq_x >= 128 ? 32 : 16;
236
257
  } else if (turing_mma_available(cc) && mmq_x >= 48) {
237
258
  return 16;
@@ -240,7 +261,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
240
261
  }
241
262
  }
242
263
 
243
- #if defined(AMD_MFMA_AVAILABLE)
264
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
244
265
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
245
266
  return mmq_x >= 128 ? 32 : 16;
246
267
  }
@@ -265,7 +286,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
265
286
  #endif // (GGML_USE_HIP)
266
287
 
267
288
  static constexpr __device__ int mmq_get_nwarps_device() {
268
- #if defined(AMD_MFMA_AVAILABLE)
289
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
269
290
  return 8;
270
291
  #else
271
292
  return 256/ggml_cuda_get_physical_warp_size();
@@ -279,14 +300,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
279
300
  constexpr int nwarps = mmq_get_nwarps_device();
280
301
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
281
302
 
282
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
303
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
283
304
  int * x_qs = (int *) x_tile;
284
305
  float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
285
306
  #else
286
307
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
287
308
  int * x_qs = (int *) x_tile;
288
309
  float * x_df = (float *) (x_qs + txs.qs);
289
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
310
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
290
311
 
291
312
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
292
313
  constexpr int nrows = warp_size / threads_per_row;
@@ -305,7 +326,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
305
326
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
306
327
  const int qs0 = get_int_b2(bxi->qs, kqsx);
307
328
 
308
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
329
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
309
330
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
310
331
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
311
332
  #else
@@ -327,11 +348,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
327
348
 
328
349
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
329
350
 
330
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
351
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
331
352
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
332
353
  #else
333
354
  x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
334
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
355
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
335
356
  }
336
357
  }
337
358
 
@@ -382,14 +403,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
382
403
  constexpr int nwarps = mmq_get_nwarps_device();
383
404
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
384
405
 
385
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
406
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
386
407
  int * x_qs = (int *) x_tile;
387
408
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
388
409
  #else
389
410
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
390
411
  int * x_qs = (int *) x_tile;
391
412
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
392
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
413
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
393
414
 
394
415
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
395
416
  constexpr int nrows = warp_size / threads_per_row;
@@ -408,12 +429,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
408
429
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
409
430
  const int qs0 = get_int_b4(bxi->qs, kqsx);
410
431
 
411
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
432
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
412
433
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
413
434
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
414
435
  #else
415
436
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
416
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
437
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
417
438
  }
418
439
 
419
440
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@@ -430,11 +451,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
430
451
 
431
452
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
432
453
 
433
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
454
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
434
455
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
435
456
  #else
436
457
  x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
437
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
458
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
438
459
  }
439
460
  }
440
461
 
@@ -485,14 +506,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
485
506
  constexpr int nwarps = mmq_get_nwarps_device();
486
507
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
487
508
 
488
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
509
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
489
510
  int * x_qs = (int *) x_tile;
490
511
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
491
512
  #else
492
513
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
493
514
  int * x_qs = (int *) x_tile;
494
515
  float * x_df = (float *) (x_qs + txs.qs);
495
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
516
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
496
517
 
497
518
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498
519
  constexpr int nrows = warp_size / threads_per_row;
@@ -527,13 +548,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
527
548
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
549
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
529
550
 
530
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
551
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
531
552
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
532
553
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
533
554
  #else
534
555
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
556
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
557
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
537
558
  }
538
559
 
539
560
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@@ -550,11 +571,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
550
571
 
551
572
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
552
573
 
553
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
574
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
554
575
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
555
576
  #else
556
577
  x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
578
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
558
579
  }
559
580
  }
560
581
 
@@ -563,14 +584,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
563
584
  constexpr int nwarps = mmq_get_nwarps_device();
564
585
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
565
586
 
566
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
587
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
567
588
  int * x_qs = (int *) x_tile;
568
589
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
569
590
  #else
570
591
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
571
592
  int * x_qs = (int *) x_tile;
572
593
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
573
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
594
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
574
595
 
575
596
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576
597
  constexpr int nrows = warp_size / threads_per_row;
@@ -603,13 +624,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
603
624
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
604
625
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
605
626
 
606
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
627
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
607
628
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
608
629
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
609
630
  #else
610
631
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611
632
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
633
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
613
634
  }
614
635
 
615
636
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@@ -626,11 +647,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
626
647
 
627
648
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
628
649
 
629
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
650
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
630
651
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
631
652
  #else
632
653
  x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
654
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
634
655
  }
635
656
  }
636
657
 
@@ -639,14 +660,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
639
660
  constexpr int nwarps = mmq_get_nwarps_device();
640
661
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
641
662
 
642
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
663
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
643
664
  int * x_qs = (int *) x_tile;
644
665
  float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
645
666
  #else
646
667
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
647
668
  int * x_qs = (int *) x_tile;
648
669
  float * x_df = (float *) (x_qs + txs.qs);
649
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
670
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
650
671
 
651
672
  // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
652
673
  constexpr int threads_per_row = 32;
@@ -665,13 +686,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
665
686
 
666
687
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
667
688
 
668
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
689
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
669
690
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670
691
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
671
692
  #else
672
693
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673
694
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
695
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
675
696
  }
676
697
 
677
698
  constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@@ -688,11 +709,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
688
709
 
689
710
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
690
711
 
691
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
712
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
692
713
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
693
714
  #else
694
715
  x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
716
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
696
717
  }
697
718
  }
698
719
 
@@ -701,14 +722,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
701
722
  constexpr int nwarps = mmq_get_nwarps_device();
702
723
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
724
 
704
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
725
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
705
726
  int * x_qs = (int *) x_tile;
706
727
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
728
  #else
708
729
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
730
  int * x_qs = (int *) x_tile;
710
731
  float * x_df = (float *) (x_qs + txs.qs);
711
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
732
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
712
733
 
713
734
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
735
  constexpr int nrows = warp_size / threads_per_row;
@@ -730,13 +751,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
730
751
  const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
752
  const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
753
 
733
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
754
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
734
755
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
756
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
757
  #else
737
758
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
759
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
760
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
740
761
  }
741
762
 
742
763
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@@ -753,11 +774,55 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
753
774
 
754
775
  const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
776
 
756
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
777
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
757
778
  x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
779
  #else
759
780
  x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
781
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
782
+ }
783
+ }
784
+
785
+ template <int mmq_y, bool need_check>
786
+ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
787
+ int * __restrict__ x_tile,
788
+ const int kbx0,
789
+ const int i_max,
790
+ const int stride) {
791
+ constexpr int nwarps = mmq_get_nwarps_device();
792
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
793
+
794
+ int * x_qs = (int *) x_tile;
795
+ uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
796
+
797
+ const int txi = threadIdx.x;
798
+
799
+ constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
800
+
801
+ constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
802
+ constexpr int rows_per_warp = warp_size / threads_per_row;
803
+ const int kbx = txi % threads_per_row;
804
+ const int row_in_warp = txi / threads_per_row;
805
+
806
+ #pragma unroll
807
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
808
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
809
+
810
+ if constexpr (need_check) {
811
+ i = min(i, i_max);
812
+ }
813
+
814
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
815
+
816
+ // quantize_mxfp4_mmq permutes nibbles to match the quantized format
817
+ const int k0 = kbx * 4;
818
+ memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
819
+
820
+ // Load E8M0 scales: pack 2 consecutive scales into one uint32
821
+ if (kbx % 2 == 0) {
822
+ uint32_t e = bxi->e;
823
+ e |= ((bxi + 1)->e << 8);
824
+ x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
825
+ }
761
826
  }
762
827
  }
763
828
 
@@ -796,10 +861,11 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
796
861
  template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
797
862
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
798
863
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
799
- #if defined(AMD_MFMA_AVAILABLE)
800
- typedef tile<16, 8, int> tile_A;
801
- typedef tile<16, 8, int> tile_B;
802
- typedef tile<16, 16, int> tile_C;
864
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
865
+ constexpr data_layout input_layout = get_input_data_layout();
866
+ typedef tile<16, 8, int, input_layout> tile_A;
867
+ typedef tile<16, 8, int, input_layout> tile_B;
868
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
803
869
 
804
870
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
805
871
  constexpr int rows_per_warp = granularity;
@@ -927,7 +993,79 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
927
993
  }
928
994
  }
929
995
  }
930
- #endif // defined(AMD_MFMA_AVAILABLE)
996
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
997
+ }
998
+
999
+ template <int mmq_x, int mmq_y>
1000
+ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001
+ const int * __restrict__ y,
1002
+ float * __restrict__ sum,
1003
+ const int k00) {
1004
+ typedef tile<16, 8, int> tile_A;
1005
+ typedef tile<8, 8, int> tile_B;
1006
+ typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
1007
+
1008
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
+ constexpr int rows_per_warp = 2 * granularity;
1010
+ constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
1011
+
1012
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
+
1014
+ // Match layout from load_tiles_mxfp4_fp4
1015
+ const int * x_qs = (const int *) x;
1016
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
+ const int * y_qs = (const int *) y + 4;
1018
+ const uint32_t * y_sc = (const uint32_t *) y;
1019
+
1020
+ // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021
+ tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022
+ uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
+
1024
+ // Block scale
1025
+ // Each thread has to point to a 4 byte scale value
1026
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
+
1028
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
+
1030
+ #pragma unroll
1031
+ for (int n = 0; n < ntx; ++n) {
1032
+ #pragma unroll
1033
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034
+ const int k0 = k00 + k01;
1035
+
1036
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037
+ MMQ_MMA_TILE_X_K_FP4);
1038
+
1039
+ // based on block-scaling document, 2 threads in each quad need to supply to the scale value
1040
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041
+ scaleA[n][k01 / (2 * QI_MXFP4)] =
1042
+ *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043
+ }
1044
+ }
1045
+
1046
+ #pragma unroll
1047
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048
+ #pragma unroll
1049
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050
+ tile_B B;
1051
+ uint32_t scaleB; // 2xN scales
1052
+
1053
+ load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
+
1055
+ scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
+
1057
+ #pragma unroll
1058
+ for (int n = 0; n < ntx; ++n) {
1059
+ tile_C C;
1060
+
1061
+ mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062
+ #pragma unroll
1063
+ for (int l = 0; l < tile_C::ne; ++l) {
1064
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065
+ }
1066
+ }
1067
+ }
1068
+ }
931
1069
  }
932
1070
 
933
1071
  template <int mmq_x, int mmq_y>
@@ -965,10 +1103,11 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
965
1103
  template <int mmq_x, int mmq_y>
966
1104
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
967
1105
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
968
- #if defined(AMD_MFMA_AVAILABLE)
969
- typedef tile<16, 8, int> tile_A;
970
- typedef tile<16, 8, int> tile_B;
971
- typedef tile<16, 16, int> tile_C;
1106
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1107
+ constexpr data_layout input_layout = get_input_data_layout();
1108
+ typedef tile<16, 8, int, input_layout> tile_A;
1109
+ typedef tile<16, 8, int, input_layout> tile_B;
1110
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
972
1111
 
973
1112
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
974
1113
  constexpr int rows_per_warp = granularity;
@@ -1087,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1087
1226
  }
1088
1227
  }
1089
1228
  }
1090
- #endif // defined(AMD_MFMA_AVAILABLE)
1229
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1091
1230
  }
1092
1231
 
1093
1232
  // Used for Q3_K, IQ2_S, and IQ2_XS
@@ -1130,10 +1269,11 @@ template <int mmq_x, int mmq_y>
1130
1269
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1131
1270
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1132
1271
  #if defined(AMD_MFMA_AVAILABLE)
1133
- typedef tile<16, 8, int> tile_A;
1134
- typedef tile<16, 8, int> tile_B;
1135
- typedef tile<16, 16, int> tile_C;
1136
- typedef tile<64, 2, int> tile_load;
1272
+ constexpr data_layout input_layout = get_input_data_layout();
1273
+ typedef tile<16, 8, int, input_layout> tile_A;
1274
+ typedef tile<16, 8, int, input_layout> tile_B;
1275
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1276
+ typedef tile<64, 2, int, input_layout> tile_load;
1137
1277
 
1138
1278
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1139
1279
  constexpr int rows_per_warp = granularity;
@@ -1170,6 +1310,55 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1170
1310
  tile_C C;
1171
1311
  mma(C, A[n], B[0]);
1172
1312
 
1313
+ #pragma unroll
1314
+ for (int l = 0; l < tile_C::ne; ++l) {
1315
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1316
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1317
+ }
1318
+ }
1319
+ }
1320
+ }
1321
+ #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1322
+ constexpr data_layout input_layout = get_input_data_layout();
1323
+ typedef tile<16, 4, int, input_layout> tile_A;
1324
+ typedef tile<16, 4, int, input_layout> tile_B;
1325
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1326
+
1327
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1328
+ constexpr int rows_per_warp = granularity;
1329
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1330
+
1331
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1332
+
1333
+ const int * x_qs = (const int *) x;
1334
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1335
+ const int * y_qs = (const int *) y + 4;
1336
+ const float * y_df = (const float *) y;
1337
+
1338
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1339
+
1340
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1341
+ const int k0 = k00 + k01;
1342
+
1343
+ tile_A A[ntx];
1344
+ #pragma unroll
1345
+ for (int n = 0; n < ntx; ++n) {
1346
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1347
+ }
1348
+
1349
+ #pragma unroll
1350
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1351
+ tile_B B;
1352
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1353
+
1354
+ const int j = j0 + tile_C::get_j(0);
1355
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1356
+
1357
+ #pragma unroll
1358
+ for (int n = 0; n < ntx; ++n) {
1359
+ tile_C C;
1360
+ mma(C, A[n], B);
1361
+
1173
1362
  #pragma unroll
1174
1363
  for (int l = 0; l < tile_C::ne; ++l) {
1175
1364
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1257,21 +1446,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1257
1446
  #else
1258
1447
  GGML_UNUSED_VARS(x, y, sum, k00);
1259
1448
  NO_DEVICE_CODE;
1260
- #endif // AMD_MFMA_AVAILABLE
1449
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
1261
1450
  }
1262
1451
 
1263
1452
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1264
1453
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265
1454
  constexpr int nwarps = mmq_get_nwarps_device();
1266
1455
 
1267
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1456
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1268
1457
  int * x_qs = (int *) x_tile;
1269
1458
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1270
1459
  #else
1271
1460
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1272
1461
  int * x_qs = (int *) x_tile;
1273
1462
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1274
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1463
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1275
1464
 
1276
1465
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277
1466
  constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@@ -1295,11 +1484,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1295
1484
 
1296
1485
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1297
1486
 
1298
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1487
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1299
1488
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1300
1489
  #else
1301
1490
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1491
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1303
1492
  }
1304
1493
 
1305
1494
  const int sc_m = bxi->scales[kqsx];
@@ -1310,11 +1499,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1310
1499
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1311
1500
  #endif // FAST_FP16_AVAILABLE
1312
1501
 
1313
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1502
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1314
1503
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1315
1504
  #else
1316
1505
  x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1506
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1318
1507
  }
1319
1508
  }
1320
1509
 
@@ -1387,10 +1576,11 @@ template <int mmq_x, int mmq_y>
1387
1576
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1388
1577
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1389
1578
  #if defined(AMD_MFMA_AVAILABLE)
1390
- typedef tile<16, 8, int> tile_A;
1391
- typedef tile<16, 8, int> tile_B;
1392
- typedef tile<16, 16, int> tile_C;
1393
- typedef tile<64, 2, int> tile_load;
1579
+ constexpr data_layout input_layout = get_input_data_layout();
1580
+ typedef tile<16, 8, int, input_layout> tile_A;
1581
+ typedef tile<16, 8, int, input_layout> tile_B;
1582
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583
+ typedef tile<64, 2, int, input_layout> tile_load;
1394
1584
 
1395
1585
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1396
1586
  constexpr int rows_per_warp = granularity;
@@ -1438,6 +1628,74 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1438
1628
  tile_C Cd;
1439
1629
  mma(Cd, A[n], B[0]);
1440
1630
 
1631
+ #pragma unroll
1632
+ for (int l = 0; l < tile_C::ne; ++l) {
1633
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635
+ float tmp = Cd.x[l]*dm.x;
1636
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637
+ tmp -= Cm.x[l]*dm.y;
1638
+ }
1639
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641
+ }
1642
+ }
1643
+ }
1644
+ }
1645
+ #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1646
+ constexpr data_layout input_layout = get_input_data_layout();
1647
+ typedef tile<16, 4, int, input_layout> tile_A;
1648
+ typedef tile<16, 4, int, input_layout> tile_B;
1649
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1650
+
1651
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1652
+ constexpr int rows_per_warp = granularity;
1653
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1654
+
1655
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1656
+
1657
+ const int * x_qs = (const int *) x;
1658
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1659
+ const int * y_qs = (const int *) y + 4;
1660
+ const half2 * y_ds = (const half2 *) y;
1661
+
1662
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1663
+
1664
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1665
+ const int k0 = k00 + k01;
1666
+
1667
+ tile_A A[ntx];
1668
+ #pragma unroll
1669
+ for (int n = 0; n < ntx; ++n) {
1670
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671
+ }
1672
+
1673
+ #pragma unroll
1674
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675
+ tile_B B;
1676
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
+
1678
+ const int j = j0 + tile_C::get_j(0);
1679
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
1680
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1681
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1682
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1683
+
1684
+ tile_C Cm;
1685
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1686
+ tile_A A1;
1687
+ #pragma unroll
1688
+ for (int l = 0; l < tile_A::ne; ++l) {
1689
+ A1.x[l] = 0x01010101;
1690
+ }
1691
+ mma(Cm, A1, B);
1692
+ }
1693
+
1694
+ #pragma unroll
1695
+ for (int n = 0; n < ntx; ++n) {
1696
+ tile_C Cd;
1697
+ mma(Cd, A[n], B);
1698
+
1441
1699
  #pragma unroll
1442
1700
  for (int l = 0; l < tile_C::ne; ++l) {
1443
1701
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1574,7 +1832,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1574
1832
  #else
1575
1833
  GGML_UNUSED_VARS(x, y, sum, k00);
1576
1834
  NO_DEVICE_CODE;
1577
- #endif // AMD_MFMA_AVAILABLE
1835
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
1578
1836
  }
1579
1837
 
1580
1838
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
@@ -1582,7 +1840,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1582
1840
  constexpr int nwarps = mmq_get_nwarps_device();
1583
1841
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1584
1842
 
1585
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1843
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1586
1844
  int * x_qs = (int *) x_tile;
1587
1845
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1588
1846
  #else
@@ -1618,11 +1876,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1618
1876
 
1619
1877
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1620
1878
 
1621
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1879
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1622
1880
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1623
1881
  #else
1624
1882
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1883
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1626
1884
  }
1627
1885
  }
1628
1886
 
@@ -1649,7 +1907,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1649
1907
 
1650
1908
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1651
1909
 
1652
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1910
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1653
1911
  const int8_t * sc8 = (const int8_t *) &sc;
1654
1912
  const float d = bxi->d;
1655
1913
 
@@ -1659,10 +1917,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1659
1917
  }
1660
1918
  #else
1661
1919
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1920
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1663
1921
  }
1664
1922
 
1665
- #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1923
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1666
1924
  #pragma unroll
1667
1925
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668
1926
  int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@@ -1675,7 +1933,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1675
1933
 
1676
1934
  x_df[i] = bxi->d;
1677
1935
  }
1678
- #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1936
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE)
1679
1937
  }
1680
1938
 
1681
1939
  template <int mmq_x, int mmq_y>
@@ -1728,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1728
1986
  constexpr int nwarps = mmq_get_nwarps_device();
1729
1987
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1730
1988
 
1731
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1989
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1732
1990
  int * x_qs = (int *) x_tile;
1733
1991
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1734
1992
  #else
@@ -1736,7 +1994,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1736
1994
  int * x_qs = (int *) x_tile;
1737
1995
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1738
1996
  int * x_sc = (int *) (x_dm + txs.dm);
1739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1997
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1740
1998
 
1741
1999
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742
2000
  constexpr int nrows = warp_size / threads_per_row;
@@ -1753,19 +2011,19 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1753
2011
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1754
2012
  const int qs0 = get_int_b4(bxi->qs, txi);
1755
2013
 
1756
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2014
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1757
2015
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758
2016
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1759
2017
  #else
1760
2018
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2019
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1762
2020
  }
1763
2021
 
1764
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2022
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1765
2023
  constexpr int rows_per_warp = warp_size / 2;
1766
2024
  #pragma unroll
1767
2025
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1768
- #if defined(AMD_MFMA_AVAILABLE)
2026
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1769
2027
  // Need if on AMD instead of % because warp_size == 64
1770
2028
  // This causes double work and throughput loss (MI300X)
1771
2029
  // H100 loses about 100 t/s with 'if' condition over '%'
@@ -1774,7 +2032,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1774
2032
  #else
1775
2033
  int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1776
2034
  {
1777
- #endif // defined(AMD_MFMA_AVAILABLE)
2035
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1778
2036
  if (need_check) {
1779
2037
  i = min(i, i_max);
1780
2038
  }
@@ -1829,7 +2087,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1829
2087
 
1830
2088
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1831
2089
  }
1832
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2090
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1833
2091
  }
1834
2092
 
1835
2093
  template <int mmq_x, int mmq_y>
@@ -1872,7 +2130,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1872
2130
  constexpr int nwarps = mmq_get_nwarps_device();
1873
2131
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1874
2132
 
1875
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2133
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1876
2134
  int * x_qs = (int *) x_tile;
1877
2135
  half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1878
2136
  #else
@@ -1908,16 +2166,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1908
2166
  const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909
2167
  const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1910
2168
 
1911
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2169
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1912
2170
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1913
2171
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1914
2172
  #else
1915
2173
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916
2174
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2175
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1918
2176
  }
1919
2177
 
1920
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2178
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1921
2179
  constexpr int rows_per_warp = warp_size / 2;
1922
2180
  #pragma unroll
1923
2181
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1930,7 +2188,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1930
2188
  #else
1931
2189
  int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1932
2190
  {
1933
- #endif // defined(AMD_MFMA_AVAILABLE)
2191
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1934
2192
  if (need_check) {
1935
2193
  i = min(i, i_max);
1936
2194
  }
@@ -1986,7 +2244,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1986
2244
 
1987
2245
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1988
2246
  }
1989
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2247
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1990
2248
  }
1991
2249
 
1992
2250
  template <int mmq_x, int mmq_y>
@@ -2029,7 +2287,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2029
2287
  constexpr int nwarps = mmq_get_nwarps_device();
2030
2288
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2031
2289
 
2032
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2290
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2033
2291
  int * x_qs = (int *) x_tile;
2034
2292
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035
2293
  int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
@@ -2038,7 +2296,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2038
2296
  int * x_qs = (int *) x_tile;
2039
2297
  float * x_df = (float *) (x_qs + txs.qs);
2040
2298
  int * x_sc = (int *) (x_df + txs.dm);
2041
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2299
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2042
2300
 
2043
2301
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044
2302
  constexpr int nrows = warp_size / threads_per_row;
@@ -2065,13 +2323,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2065
2323
  const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066
2324
  const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2067
2325
 
2068
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2326
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2069
2327
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2070
2328
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2071
2329
  #else
2072
2330
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2073
2331
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2332
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2075
2333
  }
2076
2334
 
2077
2335
  #pragma unroll
@@ -2084,11 +2342,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2084
2342
 
2085
2343
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2086
2344
 
2087
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2345
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2088
2346
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2089
2347
  #else
2090
2348
  x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2349
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2092
2350
  }
2093
2351
 
2094
2352
  constexpr int rows_per_warp = warp_size / 4;
@@ -2102,11 +2360,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2102
2360
 
2103
2361
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2104
2362
 
2105
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2363
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2106
2364
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2107
2365
  #else
2108
2366
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2367
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2110
2368
  }
2111
2369
  }
2112
2370
 
@@ -2149,10 +2407,11 @@ template <int mmq_x, int mmq_y>
2149
2407
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2150
2408
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2151
2409
  #if defined(AMD_MFMA_AVAILABLE)
2152
- typedef tile<16, 8, int> tile_A;
2153
- typedef tile<16, 8, int> tile_B;
2154
- typedef tile<16, 16, int> tile_C;
2155
- typedef tile<64, 2, int> tile_load;
2410
+ constexpr data_layout input_layout = get_input_data_layout();
2411
+ typedef tile<16, 8, int, input_layout> tile_A;
2412
+ typedef tile<16, 8, int, input_layout> tile_B;
2413
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414
+ typedef tile<64, 2, int, input_layout> tile_load;
2156
2415
 
2157
2416
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2158
2417
  constexpr int rows_per_warp = granularity;
@@ -2190,6 +2449,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2190
2449
  tile_C C;
2191
2450
  mma(C, A[n], B[0]);
2192
2451
 
2452
+ #pragma unroll
2453
+ for (int l = 0; l < tile_C::ne; ++l) {
2454
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457
+ }
2458
+ }
2459
+ }
2460
+ }
2461
+ #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
2462
+ constexpr data_layout input_layout = get_input_data_layout();
2463
+ typedef tile<16, 4, int, input_layout> tile_A;
2464
+ typedef tile<16, 4, int, input_layout> tile_B;
2465
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2466
+
2467
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2468
+ constexpr int rows_per_warp = granularity;
2469
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2470
+
2471
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2472
+
2473
+ const int * x_qs = (const int *) x;
2474
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2475
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2476
+ const int * y_qs = (const int *) y + 4;
2477
+ const float * y_df = (const float *) y;
2478
+
2479
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2480
+
2481
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2482
+ const int k0 = k00 + k01;
2483
+
2484
+ tile_A A[ntx];
2485
+ #pragma unroll
2486
+ for (int n = 0; n < ntx; ++n) {
2487
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488
+ }
2489
+
2490
+ #pragma unroll
2491
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492
+ tile_B B;
2493
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
+
2495
+ const int j = j0 + tile_C::get_j(0);
2496
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2497
+
2498
+ #pragma unroll
2499
+ for (int n = 0; n < ntx; ++n) {
2500
+ tile_C C;
2501
+ mma(C, A[n], B);
2502
+
2193
2503
  #pragma unroll
2194
2504
  for (int l = 0; l < tile_C::ne; ++l) {
2195
2505
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -2303,7 +2613,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2303
2613
  #else
2304
2614
  GGML_UNUSED_VARS(x, y, sum, k00);
2305
2615
  NO_DEVICE_CODE;
2306
- #endif // AMD_MFMA_AVAILABLE
2616
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
2307
2617
  }
2308
2618
 
2309
2619
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
@@ -2311,14 +2621,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2311
2621
  constexpr int nwarps = mmq_get_nwarps_device();
2312
2622
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2313
2623
 
2314
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2624
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2315
2625
  int * x_qs = (int *) x_tile;
2316
2626
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2317
2627
  #else
2318
2628
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2319
2629
  int * x_qs = (int *) x_tile;
2320
2630
  float * x_df = (float *) (x_qs + txs.qs);
2321
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2631
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2322
2632
 
2323
2633
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324
2634
  constexpr int nrows = warp_size / threads_per_row;
@@ -2340,13 +2650,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2340
2650
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
2651
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
2652
 
2343
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2653
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2344
2654
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345
2655
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2346
2656
  #else
2347
2657
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348
2658
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2659
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2350
2660
  }
2351
2661
 
2352
2662
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@@ -2363,11 +2673,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2363
2673
 
2364
2674
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2365
2675
 
2366
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2676
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2367
2677
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2368
2678
  #else
2369
2679
  x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2680
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2371
2681
  }
2372
2682
  }
2373
2683
 
@@ -2376,14 +2686,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2376
2686
  constexpr int nwarps = mmq_get_nwarps_device();
2377
2687
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2378
2688
 
2379
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2689
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2380
2690
  int * x_qs = (int *) x_tile;
2381
2691
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2382
2692
  #else
2383
2693
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2384
2694
  int * x_qs = (int *) x_tile;
2385
2695
  float * x_df = (float *) (x_qs + txs.qs);
2386
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2696
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2387
2697
 
2388
2698
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389
2699
  constexpr int nrows = warp_size / threads_per_row;
@@ -2405,31 +2715,31 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2405
2715
 
2406
2716
  #pragma unroll
2407
2717
  for (int l = 0; l < QR2_XXS; ++l) {
2408
- const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
2409
- const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
2718
+ const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
2719
+ const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
2410
2720
 
2411
- const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
2412
- const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
2721
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2722
+ const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
2413
2723
 
2414
- const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2415
- const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2724
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2725
+ const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
2416
2726
 
2417
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2727
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2418
2728
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2419
2729
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2420
2730
  #else
2421
2731
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422
2732
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2733
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2424
2734
  }
2425
2735
 
2426
- const int ls = aux32 >> 28;
2736
+ const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
2427
2737
  const float d = bxi->d;
2428
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2429
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2738
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2739
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
2430
2740
  #else
2431
- x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2741
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
2742
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2433
2743
  }
2434
2744
  }
2435
2745
 
@@ -2438,14 +2748,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2438
2748
  constexpr int nwarps = mmq_get_nwarps_device();
2439
2749
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2440
2750
 
2441
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2751
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2442
2752
  int * x_qs = (int *) x_tile;
2443
2753
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2444
2754
  #else
2445
2755
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2446
2756
  int * x_qs = (int *) x_tile;
2447
2757
  float * x_df = (float *) (x_qs + txs.qs);
2448
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2758
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2449
2759
 
2450
2760
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451
2761
  constexpr int nrows = warp_size / threads_per_row;
@@ -2466,30 +2776,33 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2466
2776
 
2467
2777
  #pragma unroll
2468
2778
  for (int l = 0; l < QR2_XS; ++l) {
2469
- const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
2470
- const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
2779
+ const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
2780
+ const uint32_t signs = unpack_ksigns(q2[l] >> 9);
2471
2781
 
2472
- const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2473
- const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2782
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2783
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2474
2784
 
2475
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2785
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2786
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2787
+
2788
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2476
2789
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2477
2790
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2478
2791
  #else
2479
2792
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480
2793
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2794
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2482
2795
  }
2483
2796
 
2484
2797
  const int ls = bxi->scales[kqsx];
2485
2798
  const float d = bxi->d;
2486
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2799
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2487
2800
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488
2801
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2489
2802
  #else
2490
2803
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491
2804
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2805
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2493
2806
  }
2494
2807
  }
2495
2808
 
@@ -2498,15 +2811,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2498
2811
  constexpr int nwarps = mmq_get_nwarps_device();
2499
2812
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2500
2813
 
2501
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2814
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2502
2815
  int * x_qs = (int *) x_tile;
2503
2816
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2504
2817
  #else
2505
2818
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2506
2819
  int * x_qs = (int *) x_tile;
2507
2820
  float * x_df = (float *) (x_qs + txs.qs);
2508
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2509
-
2821
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2510
2822
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511
2823
  constexpr int nrows = warp_size / threads_per_row;
2512
2824
  const int kqsx = threadIdx.x % threads_per_row;
@@ -2539,24 +2851,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2539
2851
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2540
2852
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2541
2853
 
2542
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2854
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2543
2855
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2544
2856
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2545
2857
  #else
2546
2858
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547
2859
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2860
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2549
2861
  }
2550
2862
 
2551
2863
  const int ls = bxi->scales[kqsx];
2552
2864
  const float d = bxi->d;
2553
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2865
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2554
2866
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555
2867
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2556
2868
  #else
2557
2869
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558
2870
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2871
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2560
2872
  }
2561
2873
  }
2562
2874
 
@@ -2565,14 +2877,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2565
2877
  constexpr int nwarps = mmq_get_nwarps_device();
2566
2878
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2567
2879
 
2568
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2880
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2569
2881
  int * x_qs = (int *) x_tile;
2570
2882
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2571
2883
  #else
2572
2884
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2573
2885
  int * x_qs = (int *) x_tile;
2574
2886
  float * x_df = (float *) (x_qs + txs.qs);
2575
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2887
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2576
2888
 
2577
2889
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2578
2890
  constexpr int nrows = warp_size / threads_per_row;
@@ -2595,28 +2907,30 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2595
2907
  #pragma unroll
2596
2908
  for (int l = 0; l < QR3_XXS; ++l) {
2597
2909
  const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
2910
+ const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
2598
2911
 
2599
- const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
2912
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2913
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2600
2914
 
2601
- const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2602
- const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2915
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2916
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2603
2917
 
2604
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2918
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2605
2919
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2606
2920
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2607
2921
  #else
2608
2922
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2609
2923
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2610
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2924
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2611
2925
  }
2612
2926
 
2613
2927
  const int ls = aux32 >> 28;
2614
2928
  const float d = bxi->d;
2615
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2929
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2616
2930
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2617
2931
  #else
2618
2932
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2619
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2933
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2620
2934
  }
2621
2935
  }
2622
2936
 
@@ -2625,14 +2939,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2625
2939
  constexpr int nwarps = mmq_get_nwarps_device();
2626
2940
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2627
2941
 
2628
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2942
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2629
2943
  int * x_qs = (int *) x_tile;
2630
2944
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2631
2945
  #else
2632
2946
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2633
2947
  int * x_qs = (int *) x_tile;
2634
2948
  float * x_df = (float *) (x_qs + txs.qs);
2635
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2949
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2636
2950
 
2637
2951
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2638
2952
  constexpr int nrows = warp_size / threads_per_row;
@@ -2668,22 +2982,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2668
2982
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2669
2983
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2670
2984
 
2671
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2985
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2672
2986
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2673
2987
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2674
2988
  #else
2675
2989
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2676
2990
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2677
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2991
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2678
2992
  }
2679
2993
 
2680
2994
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2681
2995
  const float d = bxi->d;
2682
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2996
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2683
2997
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2684
2998
  #else
2685
2999
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2686
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3000
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2687
3001
  }
2688
3002
  }
2689
3003
 
@@ -2692,14 +3006,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2692
3006
  constexpr int nwarps = mmq_get_nwarps_device();
2693
3007
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2694
3008
 
2695
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3009
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2696
3010
  int * x_qs = (int *) x_tile;
2697
3011
  half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2698
3012
  #else
2699
3013
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2700
3014
  int * x_qs = (int *) x_tile;
2701
3015
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2702
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3016
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2703
3017
 
2704
3018
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2705
3019
  constexpr int nrows = warp_size / threads_per_row;
@@ -2727,23 +3041,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2727
3041
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2728
3042
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2729
3043
 
2730
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3044
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2731
3045
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2732
3046
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2733
3047
  #else
2734
3048
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2735
3049
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2736
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3050
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2737
3051
  }
2738
3052
 
2739
3053
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2740
3054
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2741
3055
 
2742
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3056
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2743
3057
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2744
3058
  #else
2745
3059
  x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2746
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3060
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2747
3061
  }
2748
3062
  }
2749
3063
 
@@ -2752,14 +3066,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2752
3066
  constexpr int nwarps = mmq_get_nwarps_device();
2753
3067
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2754
3068
 
2755
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3069
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2756
3070
  int * x_qs = (int *) x_tile;
2757
3071
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2758
3072
  #else
2759
3073
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2760
3074
  int * x_qs = (int *) x_tile;
2761
3075
  float * x_df = (float *) (x_qs + txs.qs);
2762
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3076
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2763
3077
 
2764
3078
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2765
3079
  constexpr int nrows = warp_size / threads_per_row;
@@ -2779,13 +3093,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2779
3093
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780
3094
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
3095
 
2782
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3096
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2783
3097
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2784
3098
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2785
3099
  #else
2786
3100
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2787
3101
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2788
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3102
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2789
3103
  }
2790
3104
 
2791
3105
  constexpr int rows_per_warp = warp_size / 8;
@@ -2804,11 +3118,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2804
3118
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2805
3119
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2806
3120
 
2807
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3121
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2808
3122
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2809
3123
  #else
2810
3124
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2811
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3125
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2812
3126
  }
2813
3127
  }
2814
3128
 
@@ -2848,9 +3162,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2848
3162
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2849
3163
  constexpr int nwarps = mmq_get_nwarps_device();
2850
3164
 
2851
- #if defined(AMD_MFMA_AVAILABLE)
3165
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2852
3166
  constexpr int tileC_IJ = mmq_get_granularity_device(0);
2853
- typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
3167
+ typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
2854
3168
  constexpr int rows_per_warp = granularity;
2855
3169
  #else
2856
3170
  typedef tile<16, 8, int> tile_C;
@@ -2859,11 +3173,11 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2859
3173
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2860
3174
 
2861
3175
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2862
- #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
3176
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2863
3177
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2864
3178
  #else
2865
3179
  GGML_UNUSED(nwarps);
2866
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3180
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2867
3181
 
2868
3182
  #pragma unroll
2869
3183
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2937,8 +3251,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2937
3251
  template <int mmq_x, int mmq_y, bool need_check>
2938
3252
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
2939
3253
  static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
3254
+ #ifdef BLACKWELL_MMA_AVAILABLE
3255
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3256
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
3257
+ #else
2940
3258
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
2941
3259
  static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3260
+ #endif // BLACKWELL_MMA_AVAILABLE
2942
3261
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2943
3262
  };
2944
3263
 
@@ -3063,25 +3382,34 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
3063
3382
  int * tile_y = data_mul_mat_q + mmq_x;
3064
3383
  int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3065
3384
 
3066
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3385
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3067
3386
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3068
3387
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3069
3388
  #else
3070
3389
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3071
3390
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3072
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3391
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3073
3392
 
3074
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3393
+ #if defined(BLACKWELL_MMA_AVAILABLE)
3394
+ // FP4 tile stores 8 blocks
3395
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
3396
+ #else
3397
+ constexpr int ne_block = 4 * QK8_1;
3398
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
3399
+
3400
+ constexpr int ITER_K = get_iter_k(type);
3401
+ constexpr int blocks_per_iter = ITER_K / qk;
3075
3402
 
3076
3403
  float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3077
3404
 
3405
+ constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
3406
+
3078
3407
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
3079
3408
  load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
3080
-
3081
3409
  {
3082
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
3410
+ const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
3083
3411
  #pragma unroll
3084
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3412
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3085
3413
  int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3086
3414
 
3087
3415
  tile_y[l] = by0[l];
@@ -3095,9 +3423,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
3095
3423
  __syncthreads();
3096
3424
 
3097
3425
  {
3098
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
3426
+ const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
3099
3427
  #pragma unroll
3100
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3428
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3101
3429
  int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3102
3430
 
3103
3431
  tile_y[l] = by0[l];
@@ -3229,8 +3557,10 @@ static __global__ void mul_mat_q(
3229
3557
  }
3230
3558
  #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3231
3559
 
3560
+ constexpr int ITER_K = get_iter_k(type);
3561
+
3232
3562
  const int64_t blocks_per_ne00 = ncols_x / qk;
3233
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3563
+ constexpr int blocks_per_iter = ITER_K / qk;
3234
3564
 
3235
3565
  // kbc == k block continuous, current index in continuous ijk space.
3236
3566
  int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3291,7 +3621,7 @@ static __global__ void mul_mat_q(
3291
3621
  __syncthreads();
3292
3622
  }
3293
3623
 
3294
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3624
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
3295
3625
  offset_dst += it*mmq_y;
3296
3626
 
3297
3627
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3358,7 +3688,7 @@ static __global__ void mul_mat_q(
3358
3688
  __syncthreads();
3359
3689
  }
3360
3690
 
3361
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3691
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
3362
3692
  offset_dst += it*mmq_y;
3363
3693
 
3364
3694
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3372,16 +3702,25 @@ static __global__ void mul_mat_q(
3372
3702
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3373
3703
  }
3374
3704
 
3375
-
3376
3705
  template <ggml_type type, int mmq_x, bool need_check>
3377
- static __global__ void mul_mat_q_stream_k_fixup(
3378
- const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
3379
- const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
3380
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
3381
- const int ncols_max) {
3706
+ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3707
+ const int32_t * expert_bounds,
3708
+ float * __restrict__ dst,
3709
+ const float * __restrict__ tmp_last_tile,
3710
+ const int ncols_x,
3711
+ const int nrows_x,
3712
+ const int ncols_dst,
3713
+ const size_t stride_col_dst,
3714
+ const int nchannels_y,
3715
+ const size_t stride_channel_dst,
3716
+ const int nsamples_y,
3717
+ const size_t stride_sample_dst,
3718
+ const int ncols_max) {
3382
3719
  constexpr int mmq_y = get_mmq_y_device();
3383
3720
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
3384
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3721
+ constexpr int ITER_K = get_iter_k(type);
3722
+
3723
+ constexpr int blocks_per_iter = ITER_K / qk;
3385
3724
  const int64_t blocks_per_ne00 = ncols_x / qk;
3386
3725
 
3387
3726
  constexpr int nwarps = mmq_get_nwarps_device();
@@ -3494,7 +3833,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
3494
3833
  const int col_diff = col_high - col_low;
3495
3834
 
3496
3835
  for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
3497
- ids_dst_shared[j] = ids_dst[col_low + j];
3836
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3498
3837
  }
3499
3838
  __syncthreads();
3500
3839
 
@@ -3538,8 +3877,8 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
3538
3877
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3539
3878
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3540
3879
  const size_t nbs_ids = mmq_x*sizeof(int);
3541
- const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3542
- const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3880
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3881
+ const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
3543
3882
  return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3544
3883
  }
3545
3884
 
@@ -3755,4 +4094,4 @@ void ggml_cuda_op_mul_mat_q(
3755
4094
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
3756
4095
  const int64_t src1_padded_row_size, cudaStream_t stream);
3757
4096
 
3758
- bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
4097
+ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);