whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
data/ext/sources/ggml/src/ggml-cuda/mmq.cuh

@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4];    // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4*32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
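Aside on the new struct: the added static_assert holds because sizeof(block_q8_1_mmq) = 4*QK8_1 + 4*sizeof(half2) = 128 + 16 = 144 bytes, and block_fp4_mmq is 16 bytes of scales plus 128 bytes of packed values = 144 bytes as well. Below is a minimal host-side sketch of how one value could be read back from this layout; the scale byte placement, the nibble order, and the fp4_mmq_get helper are illustrative assumptions (only the E8M0 scale convention, 2^(e-127), and the nominal E2M1 value set follow the OCP MX spec; ggml's kernels use their own tables).

#include <cmath>
#include <cstdint>

// Illustrative decode of value idx (0..255) from a block_fp4_mmq-style layout.
// Assumptions, not taken from the diff: each E8M0 scale sits in the low byte of
// a 16-bit half of d4[] (the comment only says "2 packed per uint32"), and the
// low nibble of each qs byte holds the even-indexed value.
static float fp4_mmq_get(const uint32_t d4[4], const int8_t qs[128], int idx) {
    // Nominal FP4 E2M1 value set: sign bit plus {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    static const float e2m1[16] = {
         0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
    };
    const int     block = idx / 32;                                     // 8 sub-blocks of 32 values
    const uint8_t e     = (d4[block / 2] >> (16 * (block % 2))) & 0xFF; // one E8M0 scale per sub-block
    const float   scale = std::ldexp(1.0f, (int) e - 127);              // E8M0: pure power of two, bias 127
    const uint8_t byte  = (uint8_t) qs[idx / 2];                        // two FP4 codes per byte
    const uint8_t code  = (idx & 1) ? (byte >> 4) : (byte & 0x0F);
    return e2m1[code] * scale;
}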
@@ -92,7 +100,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -102,7 +110,7 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 128;
 #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
@@ -121,7 +129,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP)
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
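Sanity-checking the two new asserts with the values these macros take elsewhere in ggml (MMQ_TILE_NE_K = 32, QI8_0 = 8): MMQ_MMA_TILE_X_K_FP4 = 2*32 + 8 + 4 = 76 and MMQ_MMA_TILE_X_K_Q8_1 = 2*32 + 2*32/8 + 4 = 76. Both asserts pass: the FP4 and Q8_1 tile strides coincide at 76 ints, and 76 % 8 == 4 preserves the staggered "+ 4" padding these tiles use to keep shared-memory rows off the same banks.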
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,10 +248,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 }
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
-#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
+#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    if (amd_mfma_available(cc)) {
+    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
         return mmq_x >= 128 ? 32 : 16;
     } else if (turing_mma_available(cc) && mmq_x >= 48) {
         return 16;
@@ -240,7 +261,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     }
 }
 
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 128 ? 32 : 16;
 }
@@ -265,7 +286,7 @@ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
 #endif // (GGML_USE_HIP)
 
 static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 8;
 #else
     return 256/ggml_cuda_get_physical_warp_size();
@@ -279,14 +300,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
     constexpr int nrows = warp_size / threads_per_row;
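Worked through with QR4_0 = 2 (two 4-bit quants per byte of qs data), this gives threads_per_row = 256 / (4*2) = 32: each thread fetches one 32-bit word holding eight packed quants, so 32 threads cover one 256-value K iteration, and nrows = warp_size / 32 rows are loaded per warp pass.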
@@ -305,7 +326,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
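The two x_qs stores unpack one 32-bit word of Q4_0 data: the masked shifts split each byte into its low and high nibble, and __vsubss4 (CUDA's per-byte saturating subtract) recenters the unsigned [0, 15] codes to [-8, 7] in all four byte lanes at once. A scalar sketch of the same transform, for reference only (the helper name is ours, not from the file):

#include <cstdint>

// Scalar reference for the SIMD-in-register unpack above (illustration only).
// With inputs in [0, 15] no saturation can occur, so plain subtraction
// reproduces __vsubss4 here.
static void unpack_q4_0_word(uint32_t qs0, int8_t lo[4], int8_t hi[4]) {
    for (int b = 0; b < 4; ++b) {
        lo[b] = (int8_t) (((qs0 >> (8*b    )) & 0x0F) - 8); // (qs0 >> 0) & 0x0F0F0F0F, minus 0x08080808
        hi[b] = (int8_t) (((qs0 >> (8*b + 4)) & 0x0F) - 8); // (qs0 >> 4) & 0x0F0F0F0F, minus 0x08080808
    }
}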
@@ -327,11 +348,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
         x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -382,14 +403,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     constexpr int nwarps = mmq_get_nwarps_device();
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
     constexpr int nrows = warp_size / threads_per_row;
@@ -408,12 +429,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
     constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@@ -430,11 +451,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
         x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
@@ -485,14 +506,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
485
506
  constexpr int nwarps = mmq_get_nwarps_device();
486
507
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
487
508
 
488
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
509
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
489
510
  int * x_qs = (int *) x_tile;
490
511
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
491
512
  #else
492
513
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
493
514
  int * x_qs = (int *) x_tile;
494
515
  float * x_df = (float *) (x_qs + txs.qs);
495
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
516
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
496
517
 
497
518
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498
519
  constexpr int nrows = warp_size / threads_per_row;
@@ -527,13 +548,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
527
548
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
549
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
529
550
 
530
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
551
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
531
552
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
532
553
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
533
554
  #else
534
555
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
556
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
557
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
537
558
  }
538
559
 
539
560
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@@ -550,11 +571,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
550
571
 
551
572
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
552
573
 
553
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
574
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
554
575
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
555
576
  #else
556
577
  x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
578
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
558
579
  }
559
580
  }
560
581
 
@@ -563,14 +584,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
563
584
  constexpr int nwarps = mmq_get_nwarps_device();
564
585
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
565
586
 
566
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
587
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
567
588
  int * x_qs = (int *) x_tile;
568
589
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
569
590
  #else
570
591
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
571
592
  int * x_qs = (int *) x_tile;
572
593
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
573
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
594
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
574
595
 
575
596
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576
597
  constexpr int nrows = warp_size / threads_per_row;
@@ -603,13 +624,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
603
624
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
604
625
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
605
626
 
606
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
627
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
607
628
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
608
629
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
609
630
  #else
610
631
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611
632
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
633
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
613
634
  }
614
635
 
615
636
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@@ -626,11 +647,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
626
647
 
627
648
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
628
649
 
629
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
650
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
630
651
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
631
652
  #else
632
653
  x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
654
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
634
655
  }
635
656
  }
636
657
 
@@ -639,14 +660,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
639
660
  constexpr int nwarps = mmq_get_nwarps_device();
640
661
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
641
662
 
642
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
663
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
643
664
  int * x_qs = (int *) x_tile;
644
665
  float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
645
666
  #else
646
667
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
647
668
  int * x_qs = (int *) x_tile;
648
669
  float * x_df = (float *) (x_qs + txs.qs);
649
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
670
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
650
671
 
651
672
  // MMQ_ITER_K / (4 * QR8_0) == 64 would be required, but NV has only 32 threads per warp
652
673
  constexpr int threads_per_row = 32;
@@ -665,13 +686,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
665
686
 
666
687
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
667
688
 
668
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
689
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
669
690
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670
691
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
671
692
  #else
672
693
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673
694
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
695
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
675
696
  }
676
697
 
677
698
  constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@@ -688,11 +709,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
688
709
 
689
710
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
690
711
 
691
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
712
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
692
713
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
693
714
  #else
694
715
  x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
716
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
696
717
  }
697
718
  }
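
The threads_per_row = 32 clamp noted in the comment above works out as follows, assuming MMQ_ITER_K == 256 and QR8_0 == 1 (neither value is restated in this hunk):

    // Hedged arithmetic check: 256 values / (4 int8 per int32 load * QR8_0) = 64
    // int loads per row, twice a 32-lane NVIDIA warp, hence each lane issues two
    // int loads per row (the bxi[0] / bxi[MMQ_TILE_NE_K/QI8_0] pair above).
    constexpr int kAssumedIterK   = 256;                      // assumption
    constexpr int kIntLoadsPerRow = kAssumedIterK / (4 * 1);  // = 64
    static_assert(kIntLoadsPerRow == 2 * 32, "two int loads per 32-lane warp row");
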
698
719
 
@@ -701,14 +722,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
701
722
  constexpr int nwarps = mmq_get_nwarps_device();
702
723
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
724
 
704
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
725
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
705
726
  int * x_qs = (int *) x_tile;
706
727
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
728
  #else
708
729
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
730
  int * x_qs = (int *) x_tile;
710
731
  float * x_df = (float *) (x_qs + txs.qs);
711
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
732
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
712
733
 
713
734
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
735
  constexpr int nrows = warp_size / threads_per_row;
@@ -730,13 +751,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
730
751
  const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
752
  const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
753
 
733
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
754
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
734
755
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
756
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
757
  #else
737
758
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
759
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
760
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
740
761
  }
741
762
 
742
763
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@@ -753,11 +774,55 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
753
774
 
754
775
  const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
776
 
756
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
777
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
757
778
  x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
779
  #else
759
780
  x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
781
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
782
+ }
783
+ }
784
+
785
+ template <int mmq_y, bool need_check>
786
+ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
787
+ int * __restrict__ x_tile,
788
+ const int kbx0,
789
+ const int i_max,
790
+ const int stride) {
791
+ constexpr int nwarps = mmq_get_nwarps_device();
792
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
793
+
794
+ int * x_qs = (int *) x_tile;
795
+ uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
796
+
797
+ const int txi = threadIdx.x;
798
+
799
+ constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
800
+
801
+ constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
802
+ constexpr int rows_per_warp = warp_size / threads_per_row;
803
+ const int kbx = txi % threads_per_row;
804
+ const int row_in_warp = txi / threads_per_row;
805
+
806
+ #pragma unroll
807
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
808
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
809
+
810
+ if constexpr (need_check) {
811
+ i = min(i, i_max);
812
+ }
813
+
814
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
815
+
816
+ // quantize_mxfp4_mmq permutes nibbles to match the quantized format
817
+ const int k0 = kbx * 4;
818
+ memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
819
+
820
+ // Load E8M0 scales: pack 2 consecutive scales into one uint32
821
+ if (kbx % 2 == 0) {
822
+ uint32_t e = bxi->e;
823
+ e |= ((bxi + 1)->e << 8);
824
+ x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
825
+ }
761
826
  }
762
827
  }
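
The scale handling in load_tiles_mxfp4_fp4 relies on two small facts: E8M0 is an exponent-only 8-bit format, and the scaled-MMA path wants two consecutive block scales packed per 32-bit word. A hedged sketch (decode_e8m0_sketch mirrors what ggml_cuda_e8m0_to_fp32 is assumed to compute; the extra 0.5f factor seen in the non-FP4 loader above is not modeled here):

    #include <cstdint>

    static __device__ __forceinline__ float decode_e8m0_sketch(uint8_t e) {
        return exp2f((int) e - 127);   // exponent only: no sign, no mantissa
    }

    static __device__ __forceinline__ uint32_t pack_e8m0_pair_sketch(uint8_t e0, uint8_t e1) {
        return (uint32_t) e0 | ((uint32_t) e1 << 8);  // e0 -> bits 0..7, e1 -> bits 8..15
    }

Only threads with even kbx perform the packed store, which is why the loader guards on kbx % 2 == 0.
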
763
828
 
@@ -796,10 +861,11 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
796
861
  template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
797
862
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
798
863
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
799
- #if defined(AMD_MFMA_AVAILABLE)
800
- typedef tile<16, 8, int> tile_A;
801
- typedef tile<16, 8, int> tile_B;
802
- typedef tile<16, 16, int> tile_C;
864
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
865
+ constexpr data_layout input_layout = get_input_data_layout();
866
+ typedef tile<16, 8, int, input_layout> tile_A;
867
+ typedef tile<16, 8, int, input_layout> tile_B;
868
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
803
869
 
804
870
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
805
871
  constexpr int rows_per_warp = granularity;
@@ -927,7 +993,79 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
927
993
  }
928
994
  }
929
995
  }
930
- #endif // defined(AMD_MFMA_AVAILABLE)
996
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
997
+ }
998
+
999
+ template <int mmq_x, int mmq_y>
1000
+ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001
+ const int * __restrict__ y,
1002
+ float * __restrict__ sum,
1003
+ const int k00) {
1004
+ typedef tile<16, 8, int> tile_A;
1005
+ typedef tile<8, 8, int> tile_B;
1006
+ typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
1007
+
1008
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
+ constexpr int rows_per_warp = 2 * granularity;
1010
+ constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
1011
+
1012
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
+
1014
+ // Match layout from load_tiles_mxfp4_fp4
1015
+ const int * x_qs = (const int *) x;
1016
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
+ const int * y_qs = (const int *) y + 4;
1018
+ const uint32_t * y_sc = (const uint32_t *) y;
1019
+
1020
+ // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021
+ tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022
+ uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
+
1024
+ // Block scale
1025
+ // Each thread has to point to a 4-byte scale value
1026
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
+
1028
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
+
1030
+ #pragma unroll
1031
+ for (int n = 0; n < ntx; ++n) {
1032
+ #pragma unroll
1033
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034
+ const int k0 = k00 + k01;
1035
+
1036
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037
+ MMQ_MMA_TILE_X_K_FP4);
1038
+
1039
+ // per the block-scaling documentation, 2 threads in each quad need to supply the scale value
1040
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041
+ scaleA[n][k01 / (2 * QI_MXFP4)] =
1042
+ *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043
+ }
1044
+ }
1045
+
1046
+ #pragma unroll
1047
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048
+ #pragma unroll
1049
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050
+ tile_B B;
1051
+ uint32_t scaleB; // 2xN scales
1052
+
1053
+ load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
+
1055
+ scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
+
1057
+ #pragma unroll
1058
+ for (int n = 0; n < ntx; ++n) {
1059
+ tile_C C;
1060
+
1061
+ mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062
+ #pragma unroll
1063
+ for (int l = 0; l < tile_C::ne; ++l) {
1064
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065
+ }
1066
+ }
1067
+ }
1068
+ }
931
1069
  }
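
vec_dot_mxfp4_mxfp4_mma feeds mma_block_scaled one E8M0 scale per 32-value operand block. A hedged scalar model of what one such step is assumed to compute (the FP4 operands are written as already-dequantized floats for clarity):

    #include <cstdint>

    // Hedged reference, not the PTX semantics verbatim.
    static __device__ float block_scaled_dot_sketch(const float * a, uint8_t ea,
                                                    const float * b, uint8_t eb, int n) {
        float acc = 0.0f;
        for (int k = 0; k < n; ++k) {
            acc += a[k] * b[k];
        }
        // both block scales are exponent-only E8M0 factors
        return acc * exp2f((int) ea - 127) * exp2f((int) eb - 127);
    }

Because the scales multiply the whole block dot product, the accumulator tile (tile_C) is float rather than int, as the typedef comment above notes.
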
932
1070
 
933
1071
  template <int mmq_x, int mmq_y>
@@ -965,10 +1103,11 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
965
1103
  template <int mmq_x, int mmq_y>
966
1104
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
967
1105
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
968
- #if defined(AMD_MFMA_AVAILABLE)
969
- typedef tile<16, 8, int> tile_A;
970
- typedef tile<16, 8, int> tile_B;
971
- typedef tile<16, 16, int> tile_C;
1106
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1107
+ constexpr data_layout input_layout = get_input_data_layout();
1108
+ typedef tile<16, 8, int, input_layout> tile_A;
1109
+ typedef tile<16, 8, int, input_layout> tile_B;
1110
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
972
1111
 
973
1112
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
974
1113
  constexpr int rows_per_warp = granularity;
@@ -1087,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1087
1226
  }
1088
1227
  }
1089
1228
  }
1090
- #endif // defined(AMD_MFMA_AVAILABLE)
1229
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1091
1230
  }
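
The new fourth template argument on tile<> threads a data layout through the type system: A/B follow whatever get_input_data_layout() reports for the target, while the accumulator stays DATA_LAYOUT_J_MAJOR. A minimal sketch of the pattern with illustrative names (not ggml's definitions):

    enum class layout_sketch { i_major, j_major };

    template <int I, int J, typename T, layout_sketch L>
    struct tile_sketch {
        static constexpr int I_dim = I;
        static constexpr int J_dim = J;
        static constexpr layout_sketch layout = L;  // carried at compile time
        T x[(I * J + 31) / 32];                     // assumption: 32-lane warp
    };

    // inputs pick the architecture's layout, the accumulator is fixed J-major
    using tile_A_sketch = tile_sketch<16, 8,  int, layout_sketch::i_major>;
    using tile_C_sketch = tile_sketch<16, 16, int, layout_sketch::j_major>;

Carrying the layout as a template parameter lets load_generic/load_ldmatrix and mma() be overloaded per layout with no runtime branching.
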
1092
1231
 
1093
1232
  // Used for Q3_K, IQ2_S, and IQ2_XS
@@ -1130,10 +1269,11 @@ template <int mmq_x, int mmq_y>
1130
1269
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1131
1270
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1132
1271
  #if defined(AMD_MFMA_AVAILABLE)
1133
- typedef tile<16, 8, int> tile_A;
1134
- typedef tile<16, 8, int> tile_B;
1135
- typedef tile<16, 16, int> tile_C;
1136
- typedef tile<64, 2, int> tile_load;
1272
+ constexpr data_layout input_layout = get_input_data_layout();
1273
+ typedef tile<16, 8, int, input_layout> tile_A;
1274
+ typedef tile<16, 8, int, input_layout> tile_B;
1275
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1276
+ typedef tile<64, 2, int, input_layout> tile_load;
1137
1277
 
1138
1278
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1139
1279
  constexpr int rows_per_warp = granularity;
@@ -1170,6 +1310,55 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1170
1310
  tile_C C;
1171
1311
  mma(C, A[n], B[0]);
1172
1312
 
1313
+ #pragma unroll
1314
+ for (int l = 0; l < tile_C::ne; ++l) {
1315
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1316
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1317
+ }
1318
+ }
1319
+ }
1320
+ }
1321
+ #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
1322
+ constexpr data_layout input_layout = get_input_data_layout();
1323
+ typedef tile<16, 4, int, input_layout> tile_A;
1324
+ typedef tile<16, 4, int, input_layout> tile_B;
1325
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1326
+
1327
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1328
+ constexpr int rows_per_warp = granularity;
1329
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1330
+
1331
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1332
+
1333
+ const int * x_qs = (const int *) x;
1334
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1335
+ const int * y_qs = (const int *) y + 4;
1336
+ const float * y_df = (const float *) y;
1337
+
1338
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1339
+
1340
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1341
+ const int k0 = k00 + k01;
1342
+
1343
+ tile_A A[ntx];
1344
+ #pragma unroll
1345
+ for (int n = 0; n < ntx; ++n) {
1346
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1347
+ }
1348
+
1349
+ #pragma unroll
1350
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1351
+ tile_B B;
1352
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1353
+
1354
+ const int j = j0 + tile_C::get_j(0);
1355
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1356
+
1357
+ #pragma unroll
1358
+ for (int n = 0; n < ntx; ++n) {
1359
+ tile_C C;
1360
+ mma(C, A[n], B);
1361
+
1173
1362
  #pragma unroll
1174
1363
  for (int l = 0; l < tile_C::ne; ++l) {
1175
1364
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1257,21 +1446,21 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1257
1446
  #else
1258
1447
  GGML_UNUSED_VARS(x, y, sum, k00);
1259
1448
  NO_DEVICE_CODE;
1260
- #endif // AMD_MFMA_AVAILABLE
1449
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
1261
1450
  }
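
The comment on the new #elif branch above is a register-count observation. Assuming RDNA's 32-lane waves, it checks out as follows:

    // Hedged bookkeeping sketch: a 16x4 int tile holds 64 ints, i.e. 2 per lane
    // of a 32-lane RDNA wave, so load_generic can fill it directly. The MFMA path
    // instead stages a 64x2 tile (128 ints, 2 per lane of a 64-lane CDNA wave)
    // and reshapes it into 16x8 fragments.
    constexpr int kWmmaTileInts = 16 * 4;
    constexpr int kRdnaLanes    = 32;
    static_assert(kWmmaTileInts / kRdnaLanes == 2, "two ints per lane per 16x4 tile");
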
1262
1451
 
1263
1452
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1264
1453
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265
1454
  constexpr int nwarps = mmq_get_nwarps_device();
1266
1455
 
1267
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1456
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1268
1457
  int * x_qs = (int *) x_tile;
1269
1458
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1270
1459
  #else
1271
1460
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1272
1461
  int * x_qs = (int *) x_tile;
1273
1462
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1274
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1463
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1275
1464
 
1276
1465
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277
1466
  constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@@ -1295,11 +1484,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1295
1484
 
1296
1485
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1297
1486
 
1298
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1487
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1299
1488
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1300
1489
  #else
1301
1490
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1491
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1303
1492
  }
1304
1493
 
1305
1494
  const int sc_m = bxi->scales[kqsx];
@@ -1310,11 +1499,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1310
1499
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1311
1500
  #endif // FAST_FP16_AVAILABLE
1312
1501
 
1313
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1502
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1314
1503
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1315
1504
  #else
1316
1505
  x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1506
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1318
1507
  }
1319
1508
  }
1320
1509
 
@@ -1387,10 +1576,11 @@ template <int mmq_x, int mmq_y>
1387
1576
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1388
1577
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1389
1578
  #if defined(AMD_MFMA_AVAILABLE)
1390
- typedef tile<16, 8, int> tile_A;
1391
- typedef tile<16, 8, int> tile_B;
1392
- typedef tile<16, 16, int> tile_C;
1393
- typedef tile<64, 2, int> tile_load;
1579
+ constexpr data_layout input_layout = get_input_data_layout();
1580
+ typedef tile<16, 8, int, input_layout> tile_A;
1581
+ typedef tile<16, 8, int, input_layout> tile_B;
1582
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583
+ typedef tile<64, 2, int, input_layout> tile_load;
1394
1584
 
1395
1585
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1396
1586
  constexpr int rows_per_warp = granularity;
@@ -1438,6 +1628,74 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1438
1628
  tile_C Cd;
1439
1629
  mma(Cd, A[n], B[0]);
1440
1630
 
1631
+ #pragma unroll
1632
+ for (int l = 0; l < tile_C::ne; ++l) {
1633
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635
+ float tmp = Cd.x[l]*dm.x;
1636
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637
+ tmp -= Cm.x[l]*dm.y;
1638
+ }
1639
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641
+ }
1642
+ }
1643
+ }
1644
+ }
1645
+ #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
1646
+ constexpr data_layout input_layout = get_input_data_layout();
1647
+ typedef tile<16, 4, int, input_layout> tile_A;
1648
+ typedef tile<16, 4, int, input_layout> tile_B;
1649
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1650
+
1651
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1652
+ constexpr int rows_per_warp = granularity;
1653
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1654
+
1655
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1656
+
1657
+ const int * x_qs = (const int *) x;
1658
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1659
+ const int * y_qs = (const int *) y + 4;
1660
+ const half2 * y_ds = (const half2 *) y;
1661
+
1662
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1663
+
1664
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1665
+ const int k0 = k00 + k01;
1666
+
1667
+ tile_A A[ntx];
1668
+ #pragma unroll
1669
+ for (int n = 0; n < ntx; ++n) {
1670
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671
+ }
1672
+
1673
+ #pragma unroll
1674
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675
+ tile_B B;
1676
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
+
1678
+ const int j = j0 + tile_C::get_j(0);
1679
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
1680
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1681
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1682
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1683
+
1684
+ tile_C Cm;
1685
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1686
+ tile_A A1;
1687
+ #pragma unroll
1688
+ for (int l = 0; l < tile_A::ne; ++l) {
1689
+ A1.x[l] = 0x01010101;
1690
+ }
1691
+ mma(Cm, A1, B);
1692
+ }
1693
+
1694
+ #pragma unroll
1695
+ for (int n = 0; n < ntx; ++n) {
1696
+ tile_C Cd;
1697
+ mma(Cd, A[n], B);
1698
+
1441
1699
  #pragma unroll
1442
1700
  for (int l = 0; l < tile_C::ne; ++l) {
1443
1701
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -1574,7 +1832,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1574
1832
  #else
1575
1833
  GGML_UNUSED_VARS(x, y, sum, k00);
1576
1834
  NO_DEVICE_CODE;
1577
- #endif // AMD_MFMA_AVAILABLE
1835
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
1578
1836
  }
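
The A1 tile filled with 0x01010101 in the WMMA branch above is a standard trick: an integer MMA against an all-ones operand degenerates into per-column sums, which is exactly the correction term the Q2_K min (dm.y) needs. A hedged scalar equivalent:

    #include <cstdint>

    // Reference for the Cm computation: dp4a/MMA with A = 1,1,1,... is a plain sum.
    static __device__ int column_sum_via_ones_sketch(const int8_t * b_col, int n) {
        int acc = 0;
        for (int k = 0; k < n; ++k) {
            acc += 1 * b_col[k];
        }
        return acc;
    }

Cm is only computed for k01 >= MMQ_TILE_NE_K * 3/4; for the earlier sub-blocks the equivalent sum is read from y_ds as sB instead.
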
1579
1837
 
1580
1838
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
@@ -1582,7 +1840,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1582
1840
  constexpr int nwarps = mmq_get_nwarps_device();
1583
1841
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1584
1842
 
1585
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1843
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1586
1844
  int * x_qs = (int *) x_tile;
1587
1845
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1588
1846
  #else
@@ -1618,11 +1876,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1618
1876
 
1619
1877
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1620
1878
 
1621
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1879
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1622
1880
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1623
1881
  #else
1624
1882
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1883
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1626
1884
  }
1627
1885
  }
1628
1886
 
@@ -1649,7 +1907,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1649
1907
 
1650
1908
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1651
1909
 
1652
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1910
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1653
1911
  const int8_t * sc8 = (const int8_t *) &sc;
1654
1912
  const float d = bxi->d;
1655
1913
 
@@ -1659,10 +1917,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1659
1917
  }
1660
1918
  #else
1661
1919
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1920
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1663
1921
  }
1664
1922
 
1665
- #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1923
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1666
1924
  #pragma unroll
1667
1925
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668
1926
  int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@@ -1675,7 +1933,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1675
1933
 
1676
1934
  x_df[i] = bxi->d;
1677
1935
  }
1678
- #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1936
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1679
1937
  }
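
The packed scale decode above ors 4-bit and 2-bit fields together and recentres with __vsubss4(..., 0x20202020). A scalar rewrite of the same 6-bit Q3_K scale extraction, assuming the standard block_q3_K scale packing (it matches ggml's reference dequantizer, but the helper name here is ours):

    #include <cstdint>

    static __device__ int q3k_scale_sketch(const uint8_t * scales, int is /* 0..15 */) {
        const int lo = (is < 8) ? (scales[is] & 0x0F) : (scales[is - 8] >> 4);
        const int hi = (scales[8 + is % 4] >> (2 * (is / 4))) & 0x03;
        return (lo | (hi << 4)) - 32;   // 6-bit value recentred by 32
    }
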
1680
1938
 
1681
1939
  template <int mmq_x, int mmq_y>
@@ -1728,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1728
1986
  constexpr int nwarps = mmq_get_nwarps_device();
1729
1987
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1730
1988
 
1731
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1989
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1732
1990
  int * x_qs = (int *) x_tile;
1733
1991
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1734
1992
  #else
@@ -1736,7 +1994,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1736
1994
  int * x_qs = (int *) x_tile;
1737
1995
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1738
1996
  int * x_sc = (int *) (x_dm + txs.dm);
1739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1997
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1740
1998
 
1741
1999
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742
2000
  constexpr int nrows = warp_size / threads_per_row;
@@ -1753,19 +2011,19 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1753
2011
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1754
2012
  const int qs0 = get_int_b4(bxi->qs, txi);
1755
2013
 
1756
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2014
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1757
2015
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758
2016
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1759
2017
  #else
1760
2018
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2019
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1762
2020
  }
1763
2021
 
1764
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2022
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1765
2023
  constexpr int rows_per_warp = warp_size / 2;
1766
2024
  #pragma unroll
1767
2025
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1768
- #if defined(AMD_MFMA_AVAILABLE)
2026
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1769
2027
  // Need if on AMD instead of % because warp_size == 64
1770
2028
  // This causes double work and throughput loss (MI300X)
1771
2029
  // H100 loses about 100 t/s with 'if' condition over '%'
@@ -1774,7 +2032,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1774
2032
  #else
1775
2033
  int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1776
2034
  {
1777
- #endif // defined(AMD_MFMA_AVAILABLE)
2035
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1778
2036
  if (need_check) {
1779
2037
  i = min(i, i_max);
1780
2038
  }
@@ -1829,7 +2087,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1829
2087
 
1830
2088
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1831
2089
  }
1832
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2090
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1833
2091
  }
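
scales8 above repacks the K-quant scale/min pairs into per-thread words. For reference, the underlying 6-bit packing is the one ggml's get_scale_min_k4 decodes; a plain rewrite (same logic, our name):

    #include <cstdint>

    static __device__ void scale_min_k4_sketch(int j, const uint8_t * q,
                                               uint8_t * d, uint8_t * m) {
        if (j < 4) {
            *d = q[j] & 63;
            *m = q[j + 4] & 63;
        } else {
            *d = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            *m = (q[j + 4] >>   4) | ((q[j    ] >> 6) << 4);
        }
    }

Eight scales and eight mins, 6 bits each, fit in the 12 scale bytes of block_q4_K.
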
1834
2092
 
1835
2093
  template <int mmq_x, int mmq_y>
@@ -1872,7 +2130,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1872
2130
  constexpr int nwarps = mmq_get_nwarps_device();
1873
2131
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1874
2132
 
1875
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2133
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1876
2134
  int * x_qs = (int *) x_tile;
1877
2135
  half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1878
2136
  #else
@@ -1908,16 +2166,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1908
2166
  const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909
2167
  const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1910
2168
 
1911
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2169
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1912
2170
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1913
2171
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1914
2172
  #else
1915
2173
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916
2174
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2175
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1918
2176
  }
1919
2177
 
1920
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2178
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1921
2179
  constexpr int rows_per_warp = warp_size / 2;
1922
2180
  #pragma unroll
1923
2181
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1930,7 +2188,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1930
2188
  #else
1931
2189
  int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1932
2190
  {
1933
- #endif // defined(AMD_MFMA_AVAILABLE)
2191
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1934
2192
  if (need_check) {
1935
2193
  i = min(i, i_max);
1936
2194
  }
@@ -1986,7 +2244,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1986
2244
 
1987
2245
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1988
2246
  }
1989
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2247
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1990
2248
  }
1991
2249
 
1992
2250
  template <int mmq_x, int mmq_y>
@@ -2029,7 +2287,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2029
2287
  constexpr int nwarps = mmq_get_nwarps_device();
2030
2288
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2031
2289
 
2032
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2290
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2033
2291
  int * x_qs = (int *) x_tile;
2034
2292
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035
2293
  int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
@@ -2038,7 +2296,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2038
2296
  int * x_qs = (int *) x_tile;
2039
2297
  float * x_df = (float *) (x_qs + txs.qs);
2040
2298
  int * x_sc = (int *) (x_df + txs.dm);
2041
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2299
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2042
2300
 
2043
2301
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044
2302
  constexpr int nrows = warp_size / threads_per_row;
@@ -2065,13 +2323,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2065
2323
  const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066
2324
  const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2067
2325
 
2068
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2326
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2069
2327
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2070
2328
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2071
2329
  #else
2072
2330
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2073
2331
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2332
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2075
2333
  }
2076
2334
 
2077
2335
  #pragma unroll
@@ -2084,11 +2342,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2084
2342
 
2085
2343
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2086
2344
 
2087
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2345
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2088
2346
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2089
2347
  #else
2090
2348
  x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2349
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2092
2350
  }
2093
2351
 
2094
2352
  constexpr int rows_per_warp = warp_size / 4;
@@ -2102,11 +2360,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2102
2360
 
2103
2361
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2104
2362
 
2105
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2363
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2106
2364
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2107
2365
  #else
2108
2366
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2367
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2110
2368
  }
2111
2369
  }
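
The kq0/kq1 stores above rebuild 6-bit Q6_K values as (4 low bits | 2 high bits << 4) and recentre with __vsubss4(..., 0x20202020). The scalar equivalent, with illustrative naming:

    #include <cstdint>

    static __device__ int q6k_value_sketch(uint8_t ql_nibble, uint8_t qh_2bits) {
        return ((ql_nibble & 0x0F) | ((qh_2bits & 0x03) << 4)) - 32;  // 0..63 -> -32..31
    }
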
2112
2370
 
@@ -2149,10 +2407,11 @@ template <int mmq_x, int mmq_y>
2149
2407
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2150
2408
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2151
2409
  #if defined(AMD_MFMA_AVAILABLE)
2152
- typedef tile<16, 8, int> tile_A;
2153
- typedef tile<16, 8, int> tile_B;
2154
- typedef tile<16, 16, int> tile_C;
2155
- typedef tile<64, 2, int> tile_load;
2410
+ constexpr data_layout input_layout = get_input_data_layout();
2411
+ typedef tile<16, 8, int, input_layout> tile_A;
2412
+ typedef tile<16, 8, int, input_layout> tile_B;
2413
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414
+ typedef tile<64, 2, int, input_layout> tile_load;
2156
2415
 
2157
2416
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2158
2417
  constexpr int rows_per_warp = granularity;
@@ -2190,6 +2449,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2190
2449
  tile_C C;
2191
2450
  mma(C, A[n], B[0]);
2192
2451
 
2452
+ #pragma unroll
2453
+ for (int l = 0; l < tile_C::ne; ++l) {
2454
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457
+ }
2458
+ }
2459
+ }
2460
+ }
2461
+ #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
2462
+ constexpr data_layout input_layout = get_input_data_layout();
2463
+ typedef tile<16, 4, int, input_layout> tile_A;
2464
+ typedef tile<16, 4, int, input_layout> tile_B;
2465
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2466
+
2467
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2468
+ constexpr int rows_per_warp = granularity;
2469
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2470
+
2471
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2472
+
2473
+ const int * x_qs = (const int *) x;
2474
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2475
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2476
+ const int * y_qs = (const int *) y + 4;
2477
+ const float * y_df = (const float *) y;
2478
+
2479
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2480
+
2481
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2482
+ const int k0 = k00 + k01;
2483
+
2484
+ tile_A A[ntx];
2485
+ #pragma unroll
2486
+ for (int n = 0; n < ntx; ++n) {
2487
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488
+ }
2489
+
2490
+ #pragma unroll
2491
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492
+ tile_B B;
2493
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
+
2495
+ const int j = j0 + tile_C::get_j(0);
2496
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2497
+
2498
+ #pragma unroll
2499
+ for (int n = 0; n < ntx; ++n) {
2500
+ tile_C C;
2501
+ mma(C, A[n], B);
2502
+
2193
2503
  #pragma unroll
2194
2504
  for (int l = 0; l < tile_C::ne; ++l) {
2195
2505
  const int i = i0 + n*tile_C::I + tile_C::get_i(l);
@@ -2303,7 +2613,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2303
2613
  #else
2304
2614
  GGML_UNUSED_VARS(x, y, sum, k00);
2305
2615
  NO_DEVICE_CODE;
2306
- #endif // AMD_MFMA_AVAILABLE
2616
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
2307
2617
  }
2308
2618
 
2309
2619
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
@@ -2311,14 +2621,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2311
2621
  constexpr int nwarps = mmq_get_nwarps_device();
2312
2622
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2313
2623
 
2314
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2624
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2315
2625
  int * x_qs = (int *) x_tile;
2316
2626
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2317
2627
  #else
2318
2628
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2319
2629
  int * x_qs = (int *) x_tile;
2320
2630
  float * x_df = (float *) (x_qs + txs.qs);
2321
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2631
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2322
2632
 
2323
2633
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324
2634
  constexpr int nrows = warp_size / threads_per_row;
@@ -2340,13 +2650,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2340
2650
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
2651
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
2652
 
2343
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2653
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2344
2654
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345
2655
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2346
2656
  #else
2347
2657
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348
2658
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2659
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2350
2660
  }
2351
2661
 
2352
2662
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@@ -2363,11 +2673,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2363
2673
 
2364
2674
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2365
2675
 
2366
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2676
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2367
2677
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2368
2678
  #else
2369
2679
  x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2680
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2371
2681
  }
2372
2682
  }
2373
2683
 
@@ -2376,14 +2686,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2376
2686
  constexpr int nwarps = mmq_get_nwarps_device();
2377
2687
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2378
2688
 
2379
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2689
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2380
2690
  int * x_qs = (int *) x_tile;
2381
2691
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2382
2692
  #else
2383
2693
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2384
2694
  int * x_qs = (int *) x_tile;
2385
2695
  float * x_df = (float *) (x_qs + txs.qs);
2386
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2696
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2387
2697
 
2388
2698
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389
2699
  constexpr int nrows = warp_size / threads_per_row;
@@ -2414,22 +2724,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2414
2724
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2415
2725
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2416
2726
 
2417
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2727
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2418
2728
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2419
2729
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2420
2730
  #else
2421
2731
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422
2732
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2733
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2424
2734
  }
2425
2735
 
2426
2736
  const int ls = aux32 >> 28;
2427
2737
  const float d = bxi->d;
2428
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2738
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2429
2739
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2430
2740
  #else
2431
2741
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2742
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2433
2743
  }
2434
2744
  }
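
The grid/sign combination used throughout the IQ loaders, __vsub4(grid ^ signs, signs), conditionally negates bytes: with per-byte masks of 0x00/0xFF from __vcmpne4, (x ^ m) - m is x where m == 0 and ~x + 1 = -x where m == 0xFF. A scalar reference of the same operation:

    #include <cstdint>

    // Equals __vsub4(grid ^ sign_mask, sign_mask) byte for byte (mod-256 arithmetic).
    static __device__ int apply_byte_signs_sketch(int grid, int sign_mask) {
        uint32_t out = 0;
        for (int b = 0; b < 4; ++b) {
            int8_t v = (int8_t) (grid >> (8 * b));
            if ((sign_mask >> (8 * b)) & 0xFF) {
                v = (int8_t) -v;
            }
            out |= (uint32_t) (uint8_t) v << (8 * b);
        }
        return (int) out;
    }
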
2435
2745
 
@@ -2438,14 +2748,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2438
2748
  constexpr int nwarps = mmq_get_nwarps_device();
2439
2749
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2440
2750
 
2441
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2751
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2442
2752
  int * x_qs = (int *) x_tile;
2443
2753
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2444
2754
  #else
2445
2755
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2446
2756
  int * x_qs = (int *) x_tile;
2447
2757
  float * x_df = (float *) (x_qs + txs.qs);
2448
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2758
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2449
2759
 
2450
2760
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451
2761
  constexpr int nrows = warp_size / threads_per_row;
@@ -2472,24 +2782,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2472
2782
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2473
2783
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2474
2784
 
2475
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2785
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2476
2786
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2477
2787
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2478
2788
  #else
2479
2789
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480
2790
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2791
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2482
2792
  }
2483
2793
 
2484
2794
  const int ls = bxi->scales[kqsx];
2485
2795
  const float d = bxi->d;
2486
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2796
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2487
2797
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488
2798
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2489
2799
  #else
2490
2800
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491
2801
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2802
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2493
2803
  }
2494
2804
  }
2495
2805
 
@@ -2498,15 +2808,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2498
2808
  constexpr int nwarps = mmq_get_nwarps_device();
2499
2809
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2500
2810
 
2501
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2811
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2502
2812
  int * x_qs = (int *) x_tile;
2503
2813
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2504
2814
  #else
2505
2815
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2506
2816
  int * x_qs = (int *) x_tile;
2507
2817
  float * x_df = (float *) (x_qs + txs.qs);
2508
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2509
-
2818
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2510
2819
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511
2820
  constexpr int nrows = warp_size / threads_per_row;
2512
2821
  const int kqsx = threadIdx.x % threads_per_row;
@@ -2539,24 +2848,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2539
2848
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2540
2849
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2541
2850
 
2542
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2851
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2543
2852
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2544
2853
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2545
2854
  #else
2546
2855
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547
2856
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2857
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2549
2858
  }
2550
2859
 
2551
2860
  const int ls = bxi->scales[kqsx];
2552
2861
  const float d = bxi->d;
2553
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2862
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2554
2863
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555
2864
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2556
2865
  #else
2557
2866
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558
2867
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2868
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2560
2869
  }
2561
2870
  }
2562
2871
 
@@ -2565,14 +2874,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  constexpr int nwarps = mmq_get_nwarps_device();
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
  #else
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + txs.qs);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
  constexpr int nrows = warp_size / threads_per_row;
@@ -2601,22 +2910,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
  #else
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }

  const int ls = aux32 >> 28;
  const float d = bxi->d;
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
  #else
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }
  }

@@ -2625,14 +2934,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  constexpr int nwarps = mmq_get_nwarps_device();
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
  #else
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + txs.qs);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
  constexpr int nrows = warp_size / threads_per_row;
@@ -2668,22 +2977,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
  #else
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }

  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
  const float d = bxi->d;
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
  #else
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }
  }

@@ -2692,14 +3001,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  constexpr int nwarps = mmq_get_nwarps_device();
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  int * x_qs = (int *) x_tile;
  half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
  #else
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
  int * x_qs = (int *) x_tile;
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
  constexpr int nrows = warp_size / threads_per_row;
@@ -2727,23 +3036,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
  #else
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }

  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
  #else
  x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }
  }

@@ -2752,14 +3061,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  constexpr int nwarps = mmq_get_nwarps_device();
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
  #else
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
  int * x_qs = (int *) x_tile;
  float * x_df = (float *) (x_qs + txs.qs);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
  constexpr int nrows = warp_size / threads_per_row;
@@ -2779,13 +3088,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
  #else
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }

  constexpr int rows_per_warp = warp_size / 8;
@@ -2804,11 +3113,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
  #else
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  }
  }

@@ -2848,9 +3157,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
  constexpr int nwarps = mmq_get_nwarps_device();

- #if defined(AMD_MFMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  constexpr int tileC_IJ = mmq_get_granularity_device(0);
- typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+ typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
  constexpr int rows_per_warp = granularity;
  #else
  typedef tile<16, 8, int> tile_C;
@@ -2859,11 +3168,11 @@ static __device__ __forceinline__ void mmq_write_back_mma(
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
- #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
  #else
  GGML_UNUSED(nwarps);
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

  #pragma unroll
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2937,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
  template <int mmq_x, int mmq_y, bool need_check>
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
  static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+ #ifdef BLACKWELL_MMA_AVAILABLE
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
+ #else
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
  static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ #endif // BLACKWELL_MMA_AVAILABLE
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };
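
The MXFP4 specialization above now picks a native-FP4 loader and dot kernel when BLACKWELL_MMA_AVAILABLE is defined, and otherwise keeps the generic q8_0×q8_1 MMA path. A skeleton of that compile-time dispatch pattern, with hypothetical stand-ins (dispatch_traits, fast_path, generic_path) for the real kernels:

```cpp
// Skeleton only: fast_path/generic_path stand in for load_tiles_mxfp4_fp4
// and load_tiles_mxfp4. The preprocessor fixes the function pointer at
// compile time, so kernels read dispatch_traits::load_tiles with no
// runtime branch.
typedef void (*impl_fn)(const char *, int *, int, int, int);

static __device__ void fast_path   (const char *, int *, int, int, int) { /* native FP4 tiles */ }
static __device__ void generic_path(const char *, int *, int, int, int) { /* q8 fallback */ }

template <int mmq_x, int mmq_y>
struct dispatch_traits {
#ifdef BLACKWELL_MMA_AVAILABLE
    static constexpr impl_fn load_tiles = fast_path;
#else
    static constexpr impl_fn load_tiles = generic_path;
#endif
};
```
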

@@ -3063,25 +3377,34 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
  int * tile_y = data_mul_mat_q + mmq_x;
  int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);

- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
  #else
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ #if defined(BLACKWELL_MMA_AVAILABLE)
+ // FP4 tile stores 8 blocks
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+ #else
+ constexpr int ne_block = 4 * QK8_1;
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+ constexpr int ITER_K = get_iter_k(type);
+ constexpr int blocks_per_iter = ITER_K / qk;

  float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

+ constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
  load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
  {
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+ const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
  #pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
  int l = l0 + threadIdx.y*warp_size + threadIdx.x;

  tile_y[l] = by0[l];
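
The rewritten by0 offset factors the old expression through two named constants: sz, the size of one block_q8_1_mmq in ints, and ne_block, the number of quantized values such a block covers (4*QK8_1 off Blackwell, 8*QK_MXFP4 for the Blackwell MXFP4 tile). A host-side check, under the pre-Blackwell assumption ne_block == 4*QK8_1 and a hypothetical block size, that the two forms agree for the kb0 values the loop actually visits:

```cpp
#include <cassert>
#include <cstddef>

constexpr int         QK8_1    = 32;
constexpr std::size_t BLK_SIZE = 144;                 // hypothetical sizeof(block_q8_1_mmq)
constexpr int         sz       = BLK_SIZE / sizeof(int);
constexpr int         ne_block = 4 * QK8_1;           // values covered per block_q8_1_mmq

int main() {
    for (int qk : {32, 256}) {                        // e.g. QK4_0 and QK_K
        // step so that kb0*qk stays a multiple of ne_block, as blocks_per_iter guarantees
        const int step = ne_block / qk > 0 ? ne_block / qk : 1;
        for (int kb0 = 0; kb0 < 4096; kb0 += step) {
            const long per_kb0 = (long) (qk * BLK_SIZE / (4 * QK8_1 * sizeof(int)));
            const long off_old = kb0 * per_kb0;                   // original expression
            const long off_new = (long) (kb0 * qk / ne_block) * sz; // refactored expression
            assert(off_old == off_new);
        }
    }
    return 0;
}
```
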
@@ -3095,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
  __syncthreads();

  {
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+ const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
  #pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
  int l = l0 + threadIdx.y*warp_size + threadIdx.x;

  tile_y[l] = by0[l];
@@ -3229,8 +3552,10 @@ static __global__ void mul_mat_q(
  }
  #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

+ constexpr int ITER_K = get_iter_k(type);
+
  const int64_t blocks_per_ne00 = ncols_x / qk;
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ constexpr int blocks_per_iter = ITER_K / qk;

  // kbc == k block continuous, current index in continuous ijk space.
  int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
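
get_iter_k(type) replaces the fixed MMQ_ITER_K so that one iteration's K extent can depend on the quantization type; its definition lies outside this hunk. A self-contained sketch of such a helper, assuming (consistent with the ne_block change above, but not confirmed by this diff) that the Blackwell MXFP4 tile doubles the K extent:

```cpp
// Sketch with stand-in names; the real get_iter_k lives elsewhere in mmq.cuh.
enum sketch_type { SKETCH_TYPE_Q4_0, SKETCH_TYPE_MXFP4 };

constexpr int SKETCH_MMQ_ITER_K = 256;

constexpr int sketch_get_iter_k(sketch_type type) {
    // Assumption: the native-FP4 tile (8*QK_MXFP4 = 256 values) spans twice
    // the K range of the generic 4*QK8_1 = 128-value layout.
    return type == SKETCH_TYPE_MXFP4 ? 2*SKETCH_MMQ_ITER_K : SKETCH_MMQ_ITER_K;
}

static_assert(sketch_get_iter_k(SKETCH_TYPE_Q4_0)  == 256, "generic path unchanged");
static_assert(sketch_get_iter_k(SKETCH_TYPE_MXFP4) == 512, "wider MXFP4 iteration");
```
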
@@ -3291,7 +3616,7 @@ static __global__ void mul_mat_q(
  __syncthreads();
  }

- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
  offset_dst += it*mmq_y;

  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3358,7 +3683,7 @@ static __global__ void mul_mat_q(
  __syncthreads();
  }

- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
  offset_dst += it*mmq_y;

  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3381,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
  const int ncols_max) {
  constexpr int mmq_y = get_mmq_y_device();
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ constexpr int ITER_K = get_iter_k(type);
+
+ constexpr int blocks_per_iter = ITER_K / qk;
  const int64_t blocks_per_ne00 = ncols_x / qk;

  constexpr int nwarps = mmq_get_nwarps_device();
@@ -3494,7 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
  const int col_diff = col_high - col_low;

  for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
- ids_dst_shared[j] = ids_dst[col_low + j];
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
  }
  __syncthreads();
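
The ids_dst change above reads as a correctness fix: the stream-k fixup was reloading destination ids from the first column tile for every tile, ignoring the tile index jt. A small host-side miniature (hypothetical sizes) of the indexing each tile should use:

```cpp
#include <cstdio>

int main() {
    const int mmq_x = 8, ntiles = 3, col_low = 4; // illustrative sizes only

    for (int jt = 0; jt < ntiles; ++jt) {
        const int first = col_low + jt*mmq_x;     // fixed: offset by the tile index
        // The buggy form indexed ids_dst[col_low + j], i.e. first == col_low
        // for every jt, so tiles 1..ntiles-1 copied tile 0's ids.
        printf("tile %d reads ids_dst[%d..%d]\n", jt, first, first + mmq_x - 1);
    }
    return 0;
}
```
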
@@ -3538,8 +3865,8 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
  const size_t nbs_ids = mmq_x*sizeof(int);
- const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
- const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
  return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
  }

@@ -3755,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
  const int64_t src1_padded_row_size, cudaStream_t stream);

- bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
+ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
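
The heuristic's declaration gains an n_experts parameter, presumably letting the MMQ decision account for how ne11 is split across experts in MoE-style mul_mat_id calls. A hypothetical call site, purely illustrative (the real callers are elsewhere in the CUDA backend):

```cpp
#include <cstdint>
#include "ggml.h" // provides enum ggml_type

// Hypothetical usage sketch: a dense matmul would presumably pass a single
// expert, while an MoE path passes its expert count so the heuristic can
// judge the effective batch per expert. Nothing here is confirmed by the
// diff beyond the signature itself.
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);

static bool should_use_mmq_dense(enum ggml_type type, int cc, int64_t ne11) {
    return ggml_cuda_should_use_mmq(type, cc, ne11, /*n_experts=*/1);
}
```
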