whispercpp 1.3.3 → 1.3.5

This diff covers the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
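A note on the layout above: the new static_assert holds because 4 uint32 words of packed E8M0 scales (16 bytes) plus 128 bytes of packed FP4 nibbles total the same 144 bytes as block_q8_1_mmq (128 int8 values plus 4 half2 scales, with QK8_1 == 32). The following host-side C++ sketch of that arithmetic and of the two-scales-per-word packing is illustrative only and not part of the diff; the *_sketch names are hypothetical:

    #include <cstdint>

    // Stand-alone copy of the layout described in the diff, for illustration only.
    struct block_fp4_mmq_sketch {
        uint32_t d4[4];      // 8 E8M0 scales, 2 packed per uint32 word
        int8_t   qs[4 * 32]; // 256 FP4 values, 2 per byte
    };
    static_assert(sizeof(block_fp4_mmq_sketch) == 16 + 128, "16 B of scales + 128 B of nibbles == 144 B");

    // Packing two consecutive E8M0 scale bytes e0, e1 into one uint32 slot,
    // matching d4[k] = {s(2k), s(2k+1)} in the comment above.
    constexpr uint32_t pack_two_e8m0(uint8_t e0, uint8_t e1) {
        return (uint32_t) e0 | ((uint32_t) e1 << 8);
    }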
@@ -58,6 +66,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q8_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_MXFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -90,7 +100,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return new_mma_available(cc) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -100,13 +110,13 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     return 128;
-#else // NEW_MMA_AVAILABLE
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-    return 128;
-#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
+    return 64;
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,12 +125,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -128,8 +137,16 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(RDNA1)
     return 64;
 #else
@@ -141,19 +158,28 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 }
 
-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
-#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either,
+// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
+// 32 bit elements for the quantized data (does not include scales).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
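As a quick illustration of the padding rule in the comment above (not part of the diff, and assuming QI8_0 == 8 as in ggml's CUDA quantization constants): the mma-path row stride for a Q8_0 tile works out to 2*32 + 2*32/8 + 4 = 76 32-bit elements, and 76 % 8 == 4, while the dp4a path pads each 32-element row by one, giving an odd stride of 33. A hypothetical compile-time sketch:

    // Illustrative check of the two padding rules stated above; *_sketch names are hypothetical.
    constexpr int MMQ_TILE_NE_K_sketch = 32;
    constexpr int QI8_0_sketch = 8; // 32-bit ints of quantized data per Q8_0 block

    // mma: K stride in 32-bit elements for a Q8_0 tile row
    constexpr int mma_tile_x_k_q8_0 =
        2*MMQ_TILE_NE_K_sketch + 2*MMQ_TILE_NE_K_sketch/QI8_0_sketch + 4; // == 76
    static_assert(mma_tile_x_k_q8_0 % 8 == 4, "mma padding rule: K % 8 == 4");

    // dp4a: rows are padded by one element, e.g. x_qs[i*(MMQ_TILE_NE_K + 1) + ...]
    constexpr int dp4a_tile_x_k = MMQ_TILE_NE_K_sketch + 1; // == 33
    static_assert(dp4a_tile_x_k % 2 == 1, "dp4a padding rule: K % 2 == 1");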
@@ -162,6 +188,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -179,17 +206,20 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }
 
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
 
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -198,6 +228,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
+        case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -215,42 +247,77 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }
 
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
+        return mmq_x >= 128 ? 32 : 16;
+    } else if (turing_mma_available(cc) && mmq_x >= 48) {
+        return 16;
+    } else {
+        return 8;
+    }
 }
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(TURING_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
 #else
-static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
     return 8;
 }
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP)
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+    return amd_mfma_available(cc) ? 8 : 256/warp_size;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+    return 256/warp_size;
+}
+#endif // (GGML_USE_HIP)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    return 8;
+#else
+    return 256/ggml_cuda_get_physical_warp_size();
+#endif // AMD_MFMA_AVAILABLE
+}
 
 // ------------------------------------------------------------
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    const int kbx = threadIdx.x / QI4_0;
-    const int kqsx = threadIdx.x % QI4_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_0;
+    const int kqsx = txi % QI4_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
            i = min(i, i_max);
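Taking QI8_1 == 8 (ggml's constant for 32-bit ints of quantized data per Q8_1 block), the MMQ_TILE_Y_K defined above works out to 32 + 32/8 == 36 ints per y-tile row, matching the comment that one block_q8_1_mmq carries 32 ints of data plus 4 ints of scales; mmq_get_nwarps_host simply divides a fixed 256-thread block by the physical warp size. A hypothetical sketch of that arithmetic, not taken from the diff:

    constexpr int QI8_1_sketch = 8;
    constexpr int tile_y_k = 32 + 32/QI8_1_sketch; // 32 ints of quantized data + 4 ints of scales
    static_assert(tile_y_k == 36, "one block_q8_1_mmq per y-tile row");

    // 256 threads per block split into warps of the physical warp size:
    constexpr int nwarps_warp32 = 256 / 32; // 8 warps with 32-thread warps
    constexpr int nwarps_warp64 = 256 / 64; // 4 warps with 64-thread warps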
@@ -259,20 +326,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
-        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -280,17 +348,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -299,7 +369,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -307,7 +377,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +390,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                 }
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
-                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    const int kbx = threadIdx.x / QI4_1;
-    const int kqsx = threadIdx.x % QI4_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_1;
+    const int kqsx = txi % QI4_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -354,20 +429,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
-        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -375,17 +451,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;
@@ -394,7 +472,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -402,7 +480,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +493,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                 }
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
-                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    const int kbx = threadIdx.x / QI5_0;
-    const int kqsx = threadIdx.x % QI5_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_0;
+    const int kqsx = txi % QI5_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -449,7 +532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b2(bxi->qs, kqsx);
-        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -465,21 +548,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
         qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
-        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -487,32 +571,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    const int kbx = threadIdx.x / QI5_1;
-    const int kqsx = threadIdx.x % QI5_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_1;
+    const int kqsx = txi % QI5_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -521,7 +610,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b4(bxi->qs, kqsx);
-        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -535,21 +624,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
-        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -557,32 +647,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    const int kbx = threadIdx.x / QI8_0;
-    const int kqsx = threadIdx.x % QI8_0;
+    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
+    constexpr int threads_per_row = 32;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI8_0;
+    const int kqsx = txi % QI8_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -590,21 +686,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
-        x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
-        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -612,17 +709,128 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI_MXFP4;
+    const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+                                                            int * __restrict__ x_tile,
+                                                            const int kbx0,
+                                                            const int i_max,
+                                                            const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    int * x_qs = (int *) x_tile;
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+    const int txi = threadIdx.x;
+
+    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+    constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
+    constexpr int rows_per_warp = warp_size / threads_per_row;
+    const int kbx = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+        const int k0 = kbx * 4;
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
+            uint32_t e = bxi->e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+        }
826
+ }
827
+ }
828
+
829
+ template <int mmq_x, int mmq_y>
624
830
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
625
831
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
832
+ constexpr int nwarps = mmq_get_nwarps_device();
833
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
626
834
 
627
835
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
628
836
  const int * x_qs = (const int *) x;
@@ -631,7 +839,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
631
839
  const float * y_df = (const float *) y;
632
840
 
633
841
  // #pragma unroll
634
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
842
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
635
843
  const int k0 = k00 + k01;
636
844
 
637
845
  #pragma unroll
@@ -639,21 +847,77 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
639
847
  const int j = j0 + threadIdx.y;
640
848
 
641
849
  #pragma unroll
642
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
850
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
643
851
  const int i = i0 + threadIdx.x;
644
852
 
645
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
646
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
647
- x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
853
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
854
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
855
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
648
856
  }
649
857
  }
650
858
  }
651
859
  }
652
860
 
653
- template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
861
+ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
654
862
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
655
863
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
864
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
865
+ constexpr data_layout input_layout = get_input_data_layout();
866
+ typedef tile<16, 8, int, input_layout> tile_A;
867
+ typedef tile<16, 8, int, input_layout> tile_B;
868
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
869
+
870
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
871
+ constexpr int rows_per_warp = granularity;
872
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
873
+
874
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
875
+
876
+ const int * x_qs = (const int *) x;
877
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
878
+ const int * y_qs = (const int *) y + 4;
879
+ const float * y_df = (const float *) y;
880
+ const half2 * y_ds = (const half2 *) y;
656
881
 
882
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
883
+
884
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
885
+ const int k0 = k00 + k01;
886
+
887
+ tile_A A[ntx];
888
+ #pragma unroll
889
+ for (int n = 0; n < ntx; ++n) {
890
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
891
+ }
892
+
893
+ #pragma unroll
894
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
895
+ tile_B B;
896
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
897
+
898
+ float dB;
899
+ const int j = j0 + tile_C::get_j(0);
900
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
901
+ dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
902
+ } else {
903
+ dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
904
+ }
905
+
906
+ #pragma unroll
907
+ for (int n = 0; n < ntx; ++n) {
908
+ tile_C C;
909
+ mma(C, A[n], B);
910
+
911
+ #pragma unroll
912
+ for (int l = 0; l < tile_C::ne; ++l) {
913
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
914
+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
915
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
916
+ }
917
+ }
918
+ }
919
+ }
920
+ #else
657
921
  typedef tile<16, 8, int> tile_A;
658
922
  typedef tile< 8, 8, int> tile_B;
659
923
  typedef tile<16, 8, int> tile_C;
@@ -662,23 +926,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
662
926
  constexpr int rows_per_warp = 2 * granularity;
663
927
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
664
928
 
665
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
929
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
666
930
 
667
931
  const int * x_qs = (const int *) x;
668
- const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
932
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
669
933
  const int * y_qs = (const int *) y + 4;
670
934
  const float * y_df = (const float *) y;
671
935
  const half2 * y_ds = (const half2 *) y;
672
936
 
673
- tile_A A[ntx][WARP_SIZE/QI8_0];
674
- float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
937
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
938
+ float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
675
939
 
676
940
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
677
941
 
678
942
  #pragma unroll
679
943
  for (int n = 0; n < ntx; ++n) {
680
944
  #pragma unroll
681
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
945
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
682
946
  const int k0 = k00 + k01;
683
947
 
684
948
  load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
@@ -689,7 +953,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
689
953
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
690
954
 
691
955
  #pragma unroll
692
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
956
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
693
957
  const int k0 = k00 + k01;
694
958
 
695
959
  dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
@@ -700,7 +964,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
700
964
  #pragma unroll
701
965
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
702
966
  #pragma unroll
703
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
967
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
704
968
  tile_B B;
705
969
  float dB[tile_C::ne/2];
706
970
 
@@ -729,11 +993,86 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
729
993
  }
730
994
  }
731
995
  }
996
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
732
997
  }
733
998
 
734
- template <int mmq_x, int mmq_y, int nwarps>
999
+ template <int mmq_x, int mmq_y>
1000
+ static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001
+ const int * __restrict__ y,
1002
+ float * __restrict__ sum,
1003
+ const int k00) {
1004
+ typedef tile<16, 8, int> tile_A;
1005
+ typedef tile<8, 8, int> tile_B;
1006
+ typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
1007
+
1008
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
+ constexpr int rows_per_warp = 2 * granularity;
1010
+ constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
1011
+
1012
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
+
1014
+ // Match layout from load_tiles_mxfp4_fp4
1015
+ const int * x_qs = (const int *) x;
1016
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
+ const int * y_qs = (const int *) y + 4;
1018
+ const uint32_t * y_sc = (const uint32_t *) y;
1019
+
1020
+ // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021
+ tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022
+ uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
+
1024
+ // Block scale
1025
+ // Each thread has to point to a 4 byte scale value
1026
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
+
1028
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
+
1030
+ #pragma unroll
1031
+ for (int n = 0; n < ntx; ++n) {
1032
+ #pragma unroll
1033
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034
+ const int k0 = k00 + k01;
1035
+
1036
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037
+ MMQ_MMA_TILE_X_K_FP4);
1038
+
1039
+ // based on the block-scaling document, 2 threads in each quad need to supply the scale value
1040
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041
+ scaleA[n][k01 / (2 * QI_MXFP4)] =
1042
+ *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043
+ }
1044
+ }
1045
+
1046
+ #pragma unroll
1047
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048
+ #pragma unroll
1049
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050
+ tile_B B;
1051
+ uint32_t scaleB; // 2xN scales
1052
+
1053
+ load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
+
1055
+ scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
+
1057
+ #pragma unroll
1058
+ for (int n = 0; n < ntx; ++n) {
1059
+ tile_C C;
1060
+
1061
+ mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062
+ #pragma unroll
1063
+ for (int l = 0; l < tile_C::ne; ++l) {
1064
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065
+ }
1066
+ }
1067
+ }
1068
+ }
1069
+ }
1070
+
1071
+ template <int mmq_x, int mmq_y>
735
1072
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
736
1073
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1074
+ constexpr int nwarps = mmq_get_nwarps_device();
1075
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
737
1076
 
738
1077
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
739
1078
  const int * x_qs = (const int *) x;
@@ -742,7 +1081,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
742
1081
  const half2 * y_ds = (const half2 *) y;
743
1082
 
744
1083
  // #pragma unroll
745
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
1084
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
746
1085
  const int k0 = k00 + k01;
747
1086
 
748
1087
  #pragma unroll
@@ -750,45 +1089,96 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
750
1089
  const int j = j0 + threadIdx.y;
751
1090
 
752
1091
  #pragma unroll
753
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1092
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
754
1093
  const int i = i0 + threadIdx.x;
755
1094
 
756
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
757
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
758
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1095
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
1096
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1097
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
759
1098
  }
760
1099
  }
761
1100
  }
762
1101
  }
763
1102
 
764
- template <int mmq_x, int mmq_y, int nwarps>
1103
+ template <int mmq_x, int mmq_y>
765
1104
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
766
1105
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1106
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1107
+ constexpr data_layout input_layout = get_input_data_layout();
1108
+ typedef tile<16, 8, int, input_layout> tile_A;
1109
+ typedef tile<16, 8, int, input_layout> tile_B;
1110
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
767
1111
 
768
- typedef tile<16, 8, int> tile_A;
769
- typedef tile< 8, 8, int> tile_B;
770
- typedef tile<16, 8, int> tile_C;
1112
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1113
+ constexpr int rows_per_warp = granularity;
1114
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1115
+
1116
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1117
+
1118
+ const int * x_qs = (const int *) x;
1119
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
1120
+ const int * y_qs = (const int *) y + 4;
1121
+ const half2 * y_dm = (const half2 *) y;
1122
+
1123
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1124
+
1125
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1126
+ const int k0 = k00 + k01;
1127
+
1128
+ tile_A A[ntx];
1129
+ #pragma unroll
1130
+ for (int n = 0; n < ntx; ++n) {
1131
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1132
+ }
1133
+
1134
+ #pragma unroll
1135
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1136
+ tile_B B;
1137
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1138
+
1139
+ const int j = j0 + tile_C::get_j(0);
1140
+ const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
1141
+
1142
+ #pragma unroll
1143
+ for (int n = 0; n < ntx; ++n) {
1144
+ tile_C C;
1145
+ mma(C, A[n], B);
1146
+
1147
+ #pragma unroll
1148
+ for (int l = 0; l < tile_C::ne; ++l) {
1149
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
1150
+ float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
1151
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
1152
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
1153
+ }
1154
+ }
1155
+ }
1156
+ }
1157
+ #else
1158
+ typedef tile<16, 8, int> tile_A;
1159
+ typedef tile< 8, 8, int> tile_B;
1160
+ typedef tile<16, 8, int> tile_C;
771
1161
 
772
1162
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
773
1163
  constexpr int rows_per_warp = 2 * granularity;
774
1164
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
775
1165
 
776
- y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
1166
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
777
1167
 
778
1168
  const int * x_qs = (const int *) x;
779
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
1169
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
780
1170
  const int * y_qs = (const int *) y + 4;
781
1171
  const half2 * y_dm = (const half2 *) y;
782
1172
 
783
- tile_A A[ntx][WARP_SIZE/QI8_1];
784
- float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
1173
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
1174
+ float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
785
1175
 
786
1176
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
787
1177
 
788
1178
  #pragma unroll
789
1179
  for (int n = 0; n < ntx; ++n) {
790
1180
  #pragma unroll
791
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1181
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
792
1182
  const int k0 = k00 + k01;
793
1183
 
794
1184
  load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
@@ -799,7 +1189,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
799
1189
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
800
1190
 
801
1191
  #pragma unroll
802
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1192
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
803
1193
  const int k0 = k00 + k01;
804
1194
 
805
1195
  dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
@@ -810,7 +1200,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
810
1200
  #pragma unroll
811
1201
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
812
1202
  #pragma unroll
813
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1203
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
814
1204
  tile_B B;
815
1205
  float2 dsB[tile_C::ne/2];
816
1206
 
@@ -836,11 +1226,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
836
1226
  }
837
1227
  }
838
1228
  }
1229
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
839
1230
  }
840
1231
 
841
- template <int mmq_x, int mmq_y, int nwarps>
1232
+ // Used for Q3_K, IQ2_S, and IQ2_XS
1233
+ template <int mmq_x, int mmq_y>
842
1234
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
843
1235
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1236
+ constexpr int nwarps = mmq_get_nwarps_device();
1237
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
844
1238
 
845
1239
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
846
1240
  const int * x_qs = (const int *) x;
@@ -849,7 +1243,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
849
1243
  const float * y_df = (const float *) y;
850
1244
 
851
1245
  // #pragma unroll
852
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
1246
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
853
1247
  const int k0 = k00 + k01;
854
1248
 
855
1249
  #pragma unroll
@@ -857,23 +1251,123 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
857
1251
  const int j = j0 + threadIdx.y;
858
1252
 
859
1253
  #pragma unroll
860
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1254
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
861
1255
  const int i = i0 + threadIdx.x;
862
1256
 
863
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
864
- &x_qs[i*(2*WARP_SIZE + 1) + k0],
1257
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
1258
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
865
1259
  &y_qs[j*MMQ_TILE_Y_K + k01],
866
- &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
1260
+ &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
867
1261
  y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
868
1262
  }
869
1263
  }
870
1264
  }
871
1265
  }
872
1266
 
873
- template <int mmq_x, int mmq_y, int nwarps>
1267
+ // Used for Q3_K, IQ2_S, and IQ2_XS:
1268
+ template <int mmq_x, int mmq_y>
874
1269
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
875
1270
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
876
- #ifdef NEW_MMA_AVAILABLE
1271
+ #if defined(AMD_MFMA_AVAILABLE)
1272
+ constexpr data_layout input_layout = get_input_data_layout();
1273
+ typedef tile<16, 8, int, input_layout> tile_A;
1274
+ typedef tile<16, 8, int, input_layout> tile_B;
1275
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1276
+ typedef tile<64, 2, int, input_layout> tile_load;
1277
+
1278
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1279
+ constexpr int rows_per_warp = granularity;
1280
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1281
+
1282
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1283
+
1284
+ const int * x_qs = (const int *) x;
1285
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1286
+ const int * y_qs = (const int *) y + 4;
1287
+ const float * y_df = (const float *) y;
1288
+
1289
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1290
+
1291
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1292
+ const int k0 = k00 + k01;
1293
+
1294
+ tile_A A[ntx];
1295
+ #pragma unroll
1296
+ for (int n = 0; n < ntx; ++n) {
1297
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1298
+ }
1299
+
1300
+ #pragma unroll
1301
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1302
+ tile_B B[1];
1303
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1304
+
1305
+ const int j = j0 + tile_C::get_j(0);
1306
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1307
+
1308
+ #pragma unroll
1309
+ for (int n = 0; n < ntx; ++n) {
1310
+ tile_C C;
1311
+ mma(C, A[n], B[0]);
1312
+
1313
+ #pragma unroll
1314
+ for (int l = 0; l < tile_C::ne; ++l) {
1315
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1316
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1317
+ }
1318
+ }
1319
+ }
1320
+ }
1321
+ #elif defined(AMD_WMMA_AVAILABLE) // wmma instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
1322
+ constexpr data_layout input_layout = get_input_data_layout();
1323
+ typedef tile<16, 4, int, input_layout> tile_A;
1324
+ typedef tile<16, 4, int, input_layout> tile_B;
1325
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1326
+
1327
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1328
+ constexpr int rows_per_warp = granularity;
1329
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1330
+
1331
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1332
+
1333
+ const int * x_qs = (const int *) x;
1334
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1335
+ const int * y_qs = (const int *) y + 4;
1336
+ const float * y_df = (const float *) y;
1337
+
1338
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1339
+
1340
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1341
+ const int k0 = k00 + k01;
1342
+
1343
+ tile_A A[ntx];
1344
+ #pragma unroll
1345
+ for (int n = 0; n < ntx; ++n) {
1346
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1347
+ }
1348
+
1349
+ #pragma unroll
1350
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1351
+ tile_B B;
1352
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1353
+
1354
+ const int j = j0 + tile_C::get_j(0);
1355
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1356
+
1357
+ #pragma unroll
1358
+ for (int n = 0; n < ntx; ++n) {
1359
+ tile_C C;
1360
+ mma(C, A[n], B);
1361
+
1362
+ #pragma unroll
1363
+ for (int l = 0; l < tile_C::ne; ++l) {
1364
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1365
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1366
+ }
1367
+ }
1368
+ }
1369
+ }
1370
+ #elif defined(TURING_MMA_AVAILABLE)
877
1371
 
878
1372
  typedef tile<16, 4, int> tile_A;
879
1373
  typedef tile<16, 8, int> tile_A_8;
@@ -884,10 +1378,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
884
1378
  constexpr int rows_per_warp = 2 * granularity;
885
1379
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
886
1380
 
887
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1381
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
888
1382
 
889
1383
  const int * x_qs = (const int *) x;
890
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1384
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
891
1385
  const int * y_qs = (const int *) y + 4;
892
1386
  const float * y_df = (const float *) y;
893
1387
 
@@ -899,7 +1393,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
899
1393
  #pragma unroll
900
1394
  for (int n = 0; n < ntx; ++n) {
901
1395
  #pragma unroll
902
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1396
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
903
1397
  const int k0 = k00 + k01;
904
1398
 
905
1399
  load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
@@ -910,7 +1404,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
910
1404
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
911
1405
 
912
1406
  #pragma unroll
913
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
1407
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
914
1408
  const int k0 = k00 + k01;
915
1409
 
916
1410
  dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
@@ -921,7 +1415,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
921
1415
  #pragma unroll
922
1416
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
923
1417
  #pragma unroll
924
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1418
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
925
1419
  tile_B B[2];
926
1420
  float dB[tile_C::ne/2];
927
1421
 
@@ -950,28 +1444,31 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
950
1444
  }
951
1445
  }
952
1446
  #else
953
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1447
+ GGML_UNUSED_VARS(x, y, sum, k00);
954
1448
  NO_DEVICE_CODE;
955
- #endif // NEW_MMA_AVAILABLE
1449
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
956
1450
  }
957
1451
 
958
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1452
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
959
1453
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1454
+ constexpr int nwarps = mmq_get_nwarps_device();
960
1455
 
961
- #ifdef NEW_MMA_AVAILABLE
1456
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
962
1457
  int * x_qs = (int *) x_tile;
963
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1458
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
964
1459
  #else
965
1460
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
966
1461
  int * x_qs = (int *) x_tile;
967
1462
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
968
- #endif // NEW_MMA_AVAILABLE
1463
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
969
1464
 
970
- const int kqsx = threadIdx.x % QI2_K;
1465
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1466
+ constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
1467
+ const int kqsx = threadIdx.x % threads_per_row;
971
1468
 
972
1469
  #pragma unroll
973
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
974
- int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
1470
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1471
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
975
1472
 
976
1473
  if (need_check) {
977
1474
  i = min(i, i_max);
@@ -987,11 +1484,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
987
1484
 
988
1485
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
989
1486
 
990
- #ifdef NEW_MMA_AVAILABLE
1487
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
991
1488
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
992
1489
  #else
993
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
994
- #endif // NEW_MMA_AVAILABLE
1490
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1491
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
995
1492
  }
996
1493
 
997
1494
  const int sc_m = bxi->scales[kqsx];
@@ -1002,17 +1499,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1002
1499
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1003
1500
  #endif // FAST_FP16_AVAILABLE
1004
1501
 
1005
- #ifdef NEW_MMA_AVAILABLE
1502
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1006
1503
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1007
1504
  #else
1008
- x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
1009
- #endif // NEW_MMA_AVAILABLE
1505
+ x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1506
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1010
1507
  }
1011
1508
  }
1012
1509
 
1013
- template <int mmq_x, int mmq_y, int nwarps>
1510
+ template <int mmq_x, int mmq_y>
1014
1511
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1015
1512
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1513
+ constexpr int nwarps = mmq_get_nwarps_device();
1514
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1016
1515
 
1017
1516
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1018
1517
  const int * x_qs = (const int *) x;
@@ -1029,7 +1528,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1029
1528
  }
1030
1529
 
1031
1530
  #pragma unroll
1032
- for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1531
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1033
1532
  const int k0 = k00 + k01;
1034
1533
 
1035
1534
  #pragma unroll
@@ -1037,13 +1536,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1037
1536
  const int j = j0 + threadIdx.y;
1038
1537
 
1039
1538
  #pragma unroll
1040
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1539
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1041
1540
  const int i = i0 + threadIdx.x;
1042
1541
 
1043
1542
  constexpr int ns = 2;
1044
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1045
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1046
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1543
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1544
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1545
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1047
1546
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1048
1547
  }
1049
1548
  }
@@ -1052,7 +1551,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1052
1551
  // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
1053
1552
  // As a workaround 2 separate loops are used instead.
1054
1553
  #pragma unroll
1055
- for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1554
+ for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1056
1555
  const int k0 = k00 + k01;
1057
1556
 
1058
1557
  #pragma unroll
@@ -1060,23 +1559,158 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1060
1559
  const int j = j0 + threadIdx.y;
1061
1560
 
1062
1561
  #pragma unroll
1063
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1562
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1064
1563
  const int i = i0 + threadIdx.x;
1065
1564
 
1066
1565
  constexpr int ns = 1;
1067
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1068
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1069
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1566
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1567
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1568
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1070
1569
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1071
1570
  }
1072
1571
  }
1073
1572
  }
1074
1573
  }
1075
1574
 
1076
- template <int mmq_x, int mmq_y, int nwarps>
1575
+ template <int mmq_x, int mmq_y>
1077
1576
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1078
1577
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1079
- #ifdef NEW_MMA_AVAILABLE
1578
+ #if defined(AMD_MFMA_AVAILABLE)
1579
+ constexpr data_layout input_layout = get_input_data_layout();
1580
+ typedef tile<16, 8, int, input_layout> tile_A;
1581
+ typedef tile<16, 8, int, input_layout> tile_B;
1582
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583
+ typedef tile<64, 2, int, input_layout> tile_load;
1584
+
1585
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1586
+ constexpr int rows_per_warp = granularity;
1587
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1588
+
1589
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1590
+
1591
+ const int * x_qs = (const int *) x;
1592
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1593
+ const int * y_qs = (const int *) y + 4;
1594
+ const half2 * y_ds = (const half2 *) y;
1595
+
1596
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1597
+
1598
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1599
+ const int k0 = k00 + k01;
1600
+
1601
+ tile_A A[ntx];
1602
+ #pragma unroll
1603
+ for (int n = 0; n < ntx; ++n) {
1604
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1605
+ }
1606
+
1607
+ #pragma unroll
1608
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1609
+ tile_B B[1];
1610
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1611
+
1612
+ const int j = j0 + tile_C::get_j(0);
1613
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1614
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1615
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1616
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1617
+
1618
+ tile_C Cm;
1619
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1620
+ tile_A A1;
1621
+ A1.x[0] = 0x01010101;
1622
+ A1.x[1] = 0x01010101;
1623
+ mma(Cm, A1, B[0]);
1624
+ }
1625
+
1626
+ #pragma unroll
1627
+ for (int n = 0; n < ntx; ++n) {
1628
+ tile_C Cd;
1629
+ mma(Cd, A[n], B[0]);
1630
+
1631
+ #pragma unroll
1632
+ for (int l = 0; l < tile_C::ne; ++l) {
1633
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635
+ float tmp = Cd.x[l]*dm.x;
1636
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637
+ tmp -= Cm.x[l]*dm.y;
1638
+ }
1639
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641
+ }
1642
+ }
1643
+ }
1644
+ }
1645
+ #elif defined(AMD_WMMA_AVAILABLE) // wmma instructions can handle 16x4 tiles, so loading 64x2 tiles is not required
1646
+ constexpr data_layout input_layout = get_input_data_layout();
1647
+ typedef tile<16, 4, int, input_layout> tile_A;
1648
+ typedef tile<16, 4, int, input_layout> tile_B;
1649
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1650
+
1651
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1652
+ constexpr int rows_per_warp = granularity;
1653
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1654
+
1655
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1656
+
1657
+ const int * x_qs = (const int *) x;
1658
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1659
+ const int * y_qs = (const int *) y + 4;
1660
+ const half2 * y_ds = (const half2 *) y;
1661
+
1662
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1663
+
1664
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1665
+ const int k0 = k00 + k01;
1666
+
1667
+ tile_A A[ntx];
1668
+ #pragma unroll
1669
+ for (int n = 0; n < ntx; ++n) {
1670
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671
+ }
1672
+
1673
+ #pragma unroll
1674
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675
+ tile_B B;
1676
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
+
1678
+ const int j = j0 + tile_C::get_j(0);
1679
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
1680
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1681
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1682
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1683
+
1684
+ tile_C Cm;
1685
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1686
+ tile_A A1;
1687
+ #pragma unroll
1688
+ for (int l = 0; l < tile_A::ne; ++l) {
1689
+ A1.x[l] = 0x01010101;
1690
+ }
1691
+ mma(Cm, A1, B);
1692
+ }
1693
+
1694
+ #pragma unroll
1695
+ for (int n = 0; n < ntx; ++n) {
1696
+ tile_C Cd;
1697
+ mma(Cd, A[n], B);
1698
+
1699
+ #pragma unroll
1700
+ for (int l = 0; l < tile_C::ne; ++l) {
1701
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1702
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1703
+ float tmp = Cd.x[l]*dm.x;
1704
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1705
+ tmp -= Cm.x[l]*dm.y;
1706
+ }
1707
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1708
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1709
+ }
1710
+ }
1711
+ }
1712
+ }
1713
+ #elif defined(TURING_MMA_AVAILABLE)
1080
1714
 
1081
1715
  typedef tile<16, 4, int> tile_A;
1082
1716
  typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1721,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1087
1721
  constexpr int rows_per_warp = 2 * granularity;
1088
1722
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1089
1723
 
1090
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1724
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1091
1725
 
1092
1726
  const int * x_qs = (const int *) x;
1093
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1727
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1094
1728
  const int * y_qs = (const int *) y + 4;
1095
1729
  const half2 * y_ds = (const half2 *) y;
1096
1730
 
@@ -1103,7 +1737,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1103
1737
  #pragma unroll
1104
1738
  for (int n = 0; n < ntx; ++n) {
1105
1739
  #pragma unroll
1106
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1740
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1107
1741
  const int k0 = k00 + k01;
1108
1742
 
1109
1743
  load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1117
1751
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1118
1752
 
1119
1753
  #pragma unroll
1120
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
1754
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
1121
1755
  const int k0 = k00 + k01;
1122
1756
 
1123
1757
  const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1774,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1140
1774
  }
1141
1775
 
1142
1776
  #pragma unroll
1143
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1777
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1144
1778
  tile_B B[2];
1145
1779
 
1146
1780
  // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1782,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1148
1782
  load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1149
1783
 
1150
1784
  tile_C Cm[2];
1151
- if (k01 >= WARP_SIZE * 3/4) {
1785
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1152
1786
  tile_A A1;
1153
1787
  A1.x[0] = 0x01010101;
1154
1788
  A1.x[1] = 0x01010101;
@@ -1166,16 +1800,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1166
1800
  #pragma unroll
1167
1801
  for (int l = 0; l < tile_C::ne; ++l) {
1168
1802
  float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1169
- if (k01 >= WARP_SIZE * 3/4) {
1803
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1170
1804
  tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1171
1805
  }
1172
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1806
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
1173
1807
  }
1174
1808
  }
1175
1809
  }
1176
1810
 
1177
1811
  #pragma unroll
1178
- for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1812
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
1179
1813
  float2 sB[tile_C::ne/2];
1180
1814
 
1181
1815
  #pragma unroll
@@ -1196,29 +1830,33 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1196
1830
  }
1197
1831
  }
1198
1832
  #else
1199
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1833
+ GGML_UNUSED_VARS(x, y, sum, k00);
1200
1834
  NO_DEVICE_CODE;
1201
- #endif // NEW_MMA_AVAILABLE
1835
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
1202
1836
  }
1203
1837
 
1204
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1838
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1205
1839
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1840
+ constexpr int nwarps = mmq_get_nwarps_device();
1841
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1206
1842
 
1207
- #ifdef NEW_MMA_AVAILABLE
1843
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1208
1844
  int * x_qs = (int *) x_tile;
1209
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1845
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1210
1846
  #else
1211
1847
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1212
1848
  int * x_qs = (int *) x_tile;
1213
1849
  float * x_df = (float *) (x_qs + txs.qs);
1214
1850
  int * x_sc = (int *) (x_df + txs.dm);
1215
- #endif // NEW_MMA_AVAILABLE
1851
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1216
1852
 
1217
- const int kqsx = threadIdx.x % QI3_K;
1853
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1854
+ constexpr int nrows = warp_size / threads_per_row;
1855
+ const int kqsx = threadIdx.x % threads_per_row;
1218
1856
 
1219
1857
  #pragma unroll
1220
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
1221
- int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
1858
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1859
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1222
1860
 
1223
1861
  if (need_check) {
1224
1862
  i = min(i, i_max);
@@ -1238,17 +1876,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1238
1876
 
1239
1877
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1240
1878
 
1241
- #ifdef NEW_MMA_AVAILABLE
1879
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1242
1880
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1243
1881
  #else
1244
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1245
- #endif // NEW_MMA_AVAILABLE
1882
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1883
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1246
1884
  }
1247
1885
  }
1248
1886
 
1887
+ constexpr int rows_per_warp = warp_size / 4;
1249
1888
  #pragma unroll
1250
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1251
- int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
1889
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1890
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
1252
1891
 
1253
1892
  if (need_check) {
1254
1893
  i = min(i, i_max);
@@ -1256,7 +1895,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1256
1895
 
1257
1896
  const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1258
1897
 
1259
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1898
+ const int ksc = threadIdx.x % 4;
1260
1899
 
1261
1900
  const int ksc_low = ksc % (QI3_K/8);
1262
1901
  const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1907,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1268
1907
 
1269
1908
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1270
1909
 
1271
- #ifdef NEW_MMA_AVAILABLE
1910
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1272
1911
  const int8_t * sc8 = (const int8_t *) &sc;
1273
1912
  const float d = bxi->d;
1274
1913
 
1275
1914
  #pragma unroll
1276
1915
  for (int l = 0; l < int(sizeof(int)); ++l) {
1277
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
1916
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
1278
1917
  }
1279
1918
  #else
1280
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1281
- #endif // NEW_MMA_AVAILABLE
1919
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1920
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1282
1921
  }
1283
1922
 
1284
- #ifndef NEW_MMA_AVAILABLE
1923
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1285
1924
  #pragma unroll
1286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1287
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1925
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1926
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1288
1927
 
1289
1928
  if (need_check) {
1290
1929
  i = min(i, i_max);
@@ -1294,12 +1933,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1294
1933
 
1295
1934
  x_df[i] = bxi->d;
1296
1935
  }
1297
- #endif // NEW_MMA_AVAILABLE
1936
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
1298
1937
  }
1299
1938
 
1300
- template <int mmq_x, int mmq_y, int nwarps>
1939
+ template <int mmq_x, int mmq_y>
1301
1940
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1302
1941
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1942
+ constexpr int nwarps = mmq_get_nwarps_device();
1943
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1303
1944
 
1304
1945
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1305
1946
  const int * x_qs = (const int *) x;
@@ -1309,7 +1950,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1309
1950
  const float * y_df = (const float *) y;
1310
1951
 
1311
1952
  // #pragma unroll
1312
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1953
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1313
1954
  const int k0 = k00 + k01;
1314
1955
 
1315
1956
  #pragma unroll
@@ -1317,13 +1958,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1317
1958
  const int j = j0 + threadIdx.y;
1318
1959
 
1319
1960
  #pragma unroll
1320
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1961
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1321
1962
  const int i = i0 + threadIdx.x;
1322
1963
 
1323
- const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
1964
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
1324
1965
 
1325
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1326
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1966
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
1967
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1327
1968
  x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1328
1969
  }
1329
1970
  }
@@ -1340,72 +1981,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
1340
1981
  ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1341
1982
  }
1342
1983
 
1343
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1984
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1344
1985
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1986
+ constexpr int nwarps = mmq_get_nwarps_device();
1987
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1345
1988
 
1346
- #ifdef NEW_MMA_AVAILABLE
1989
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1347
1990
  int * x_qs = (int *) x_tile;
1348
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1991
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1349
1992
  #else
1350
1993
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1351
1994
  int * x_qs = (int *) x_tile;
1352
1995
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1353
1996
  int * x_sc = (int *) (x_dm + txs.dm);
1354
- #endif // NEW_MMA_AVAILABLE
1997
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1998
+
1999
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
2000
+ constexpr int nrows = warp_size / threads_per_row;
2001
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1355
2002
 
1356
2003
  #pragma unroll
1357
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1358
- int i = i0 + threadIdx.y;
2004
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2005
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1359
2006
 
1360
2007
  if (need_check) {
1361
2008
  i = min(i, i_max);
1362
2009
  }
1363
2010
 
1364
2011
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1365
- const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
2012
+ const int qs0 = get_int_b4(bxi->qs, txi);
1366
2013
 
1367
- #ifdef NEW_MMA_AVAILABLE
1368
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1369
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
2014
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2015
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
2016
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1370
2017
  #else
1371
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1372
- #endif // NEW_MMA_AVAILABLE
2018
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
2019
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1373
2020
  }
1374
2021
 
1375
- #ifdef NEW_MMA_AVAILABLE
1376
-
2022
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2023
+ constexpr int rows_per_warp = warp_size / 2;
1377
2024
  #pragma unroll
1378
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1379
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1380
-
1381
- if (need_check) {
1382
- i = min(i, i_max);
1383
- }
2025
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2026
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2027
+ // Need if on AMD instead of % because warp_size == 64
2028
+ // This causes double work and throughput loss (MI300X)
2029
+ // H100 loses about 100 t/s with 'if' condition over '%'
2030
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
2031
+ if (i < mmq_y) {
2032
+ #else
2033
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
2034
+ {
2035
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2036
+ if (need_check) {
2037
+ i = min(i, i_max);
2038
+ }
1384
2039
 
1385
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
2040
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1386
2041
 
1387
- const int * scales = (const int *) bxi->scales;
1388
- const int ksc = threadIdx.x % (WARP_SIZE/16);
2042
+ const int * scales = (const int *) bxi->scales;
2043
+ const int ksc = threadIdx.x % 2;
1389
2044
 
1390
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1391
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
2045
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
2046
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1392
2047
 
1393
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1394
- const uint8_t * m8 = (const uint8_t *) &m32;
2048
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
2049
+ const uint8_t * m8 = (const uint8_t *) &m32;
1395
2050
 
1396
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
2051
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1397
2052
 
1398
- #pragma unroll
1399
- for (int l = 0; l < int(sizeof(int)); ++l) {
1400
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2053
+ #pragma unroll
2054
+ for (int l = 0; l < sizeof(int); ++l) {
2055
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2056
+ }
1401
2057
  }
1402
2058
  }
1403
-
1404
2059
  #else
1405
-
1406
2060
  #pragma unroll
1407
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
1408
- int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
2061
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2062
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1409
2063
 
1410
2064
  if (need_check) {
1411
2065
  i = min(i, i_max);
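The scale-loading loops added in this hunk pick their row either with a "% mmq_y" wrap (the NVIDIA branch) or with an "if (i < mmq_y)" guard (the AMD MFMA/WMMA branch), as the new comments earlier in the hunk note. The sketch below only illustrates, with made-up warp/tile sizes, what each strategy does with lanes whose raw index falls outside the tile; the MI300X/H100 throughput remarks in the comments are empirical observations, not something this sketch reproduces.

// Illustrative only: compare the '%' wrap and the 'if' guard for one
// scale-loading iteration. warp_size, nwarps and mmq_y are example values,
// not the kernel's real template parameters.
#include <cstdio>

int main() {
    const int warp_size = 64, nwarps = 4, mmq_y = 64;   // example shape only
    const int rows_per_warp = warp_size / 2;            // two lanes cooperate per row
    for (int ty = 0; ty < nwarps; ++ty) {
        for (int tx = 0; tx < warp_size; tx += 2) {      // one representative lane per row
            const int  i_raw     = ty*rows_per_warp + tx/2;  // i0 == 0 iteration
            const int  i_wrapped = i_raw % mmq_y;            // wrap strategy
            const bool active    = i_raw < mmq_y;            // guard strategy
            if (!active) {
                printf("lane (%d,%2d): raw %3d -> rewrites row %2d under wrap, or is skipped under the guard\n",
                       ty, tx, i_raw, i_wrapped);
            }
        }
    }
    return 0;
}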
@@ -1415,30 +2069,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1415
2069
 
1416
2070
  x_dm[i] = bxi->dm;
1417
2071
  }
1418
-
2072
+ constexpr int rows_per_warp = warp_size / 4;
1419
2073
  #pragma unroll
1420
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1421
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
2074
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2075
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1422
2076
 
1423
2077
  if (need_check) {
1424
2078
  i = min(i, i_max);
1425
2079
  }
1426
2080
 
1427
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
2081
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
1428
2082
 
1429
2083
  const int * scales = (const int *) bxi->scales;
1430
2084
 
1431
- const int ksc = threadIdx.x % (WARP_SIZE/8);
2085
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1432
2086
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1433
2087
 
1434
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
2088
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1435
2089
  }
1436
- #endif // NEW_MMA_AVAILABLE
2090
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1437
2091
  }
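The rewritten loaders replace hard-coded WARP_SIZE arithmetic with threads_per_row, nrows and txi, so the same source serves 32-wide and 64-wide warps: a warp is split into nrows rows of threads_per_row lanes, and txi is a lane's position inside its row. A minimal host-side sketch of that mapping, with illustrative stand-ins for MMQ_ITER_K and the QR ratio (the kernel derives the real values from its quantization constants):

// Sketch of the lane -> (row, txi) mapping used by the new load_tiles_* code.
#include <cstdio>

static void show_mapping(int warp_size, int mmq_iter_k, int qr) {
    const int threads_per_row = mmq_iter_k / (4 * qr);
    const int nrows           = warp_size / threads_per_row;
    printf("warp_size=%2d -> threads_per_row=%d, rows per warp=%d\n",
           warp_size, threads_per_row, nrows);
    for (int tid = 0; tid < warp_size; ++tid) {
        const int txi = warp_size > threads_per_row ? tid % threads_per_row : tid;
        const int row = warp_size > threads_per_row ? tid / threads_per_row : 0;
        if (tid < 4 || tid >= warp_size - 2) {   // print a few representative lanes
            printf("  lane %2d -> row offset %d, txi %2d\n", tid, row, txi);
        }
    }
}

int main() {
    show_mapping(32, 256, 2);   // 32-lane warp: one row per warp, txi == threadIdx.x
    show_mapping(64, 256, 2);   // 64-lane warp: two rows per warp
    return 0;
}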
1438
2092
 
1439
- template <int mmq_x, int mmq_y, int nwarps>
2093
+ template <int mmq_x, int mmq_y>
1440
2094
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1441
2095
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2096
+ constexpr int nwarps = mmq_get_nwarps_device();
2097
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1442
2098
 
1443
2099
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1444
2100
  const int * x_qs = (const int *) x;
@@ -1448,7 +2104,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1448
2104
  const half2 * y_ds = (const half2 *) y;
1449
2105
 
1450
2106
  // #pragma unroll
1451
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
2107
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1452
2108
  const int k0 = k00 + k01;
1453
2109
 
1454
2110
  #pragma unroll
@@ -1456,97 +2112,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1456
2112
  const int j = j0 + threadIdx.y;
1457
2113
 
1458
2114
  #pragma unroll
1459
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2115
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1460
2116
  const int i = i0 + threadIdx.x;
1461
2117
 
1462
- const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
2118
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
1463
2119
 
1464
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1465
- &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2120
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
2121
+ &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1466
2122
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1467
2123
  }
1468
2124
  }
1469
2125
  }
1470
2126
  }
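The *_dp4a variants above ultimately reduce to packed int8 dot products, which the device executes as a single __dp4a(a, b, c) instruction. A scalar reference of that four-byte multiply-accumulate, purely to make the packing explicit:

// Host-side reference for the dp4a operation: four signed bytes per operand,
// multiplied pairwise and accumulated into c.
#include <cstdint>
#include <cstdio>

static int32_t pack4(int8_t b0, int8_t b1, int8_t b2, int8_t b3) {
    return (int32_t)((uint8_t)b0 | ((uint32_t)(uint8_t)b1 << 8) |
                     ((uint32_t)(uint8_t)b2 << 16) | ((uint32_t)(uint8_t)b3 << 24));
}

static int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
    for (int k = 0; k < 4; ++k) {
        const int8_t ab = (int8_t)(((uint32_t)a >> 8*k) & 0xFF);  // k-th byte of a
        const int8_t bb = (int8_t)(((uint32_t)b >> 8*k) & 0xFF);  // k-th byte of b
        c += ab * bb;
    }
    return c;
}

int main() {
    const int32_t a = pack4(1, -2, 3, -4);
    const int32_t b = pack4(5,  6, -7, 8);
    printf("dp4a = %d (expected 1*5 + -2*6 + 3*-7 + -4*8 = -60)\n", dp4a_ref(a, b, 0));
    return 0;
}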
1471
2127
 
1472
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
2128
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1473
2129
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2130
+ constexpr int nwarps = mmq_get_nwarps_device();
2131
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1474
2132
 
1475
- #ifdef NEW_MMA_AVAILABLE
2133
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1476
2134
  int * x_qs = (int *) x_tile;
1477
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
2135
+ half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1478
2136
  #else
1479
2137
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1480
2138
  int * x_qs = (int *) x_tile;
1481
2139
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1482
2140
  int * x_sc = (int *) (x_dm + txs.dm);
1483
- #endif // NEW_MMA_AVAILABLE
2141
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2142
+
2143
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
2144
+ constexpr int nrows = warp_size / threads_per_row;
2145
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1484
2146
 
1485
2147
  #pragma unroll
1486
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1487
- int i = i0 + threadIdx.y;
2148
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2149
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1488
2150
 
1489
2151
  if (need_check) {
1490
2152
  i = min(i, i_max);
1491
2153
  }
1492
2154
 
1493
2155
  const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1494
- const int ky = QR5_K*threadIdx.x;
2156
+ const int ky = QR5_K*txi;
1495
2157
 
1496
- const int ql = get_int_b4(bxi->qs, threadIdx.x);
2158
+ const int ql = get_int_b4(bxi->qs, txi);
1497
2159
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1498
2160
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1499
2161
 
1500
- const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
1501
- const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1502
- const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
2162
+ const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
2163
+ const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
2164
+ const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1503
2165
 
1504
- const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1505
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
2166
+ const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
2167
+ const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1506
2168
 
1507
- #ifdef NEW_MMA_AVAILABLE
2169
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1508
2170
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1509
2171
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1510
2172
  #else
1511
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1512
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1513
- #endif // NEW_MMA_AVAILABLE
2173
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
2174
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
2175
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1514
2176
  }
1515
2177
 
1516
- #ifdef NEW_MMA_AVAILABLE
1517
-
2178
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2179
+ constexpr int rows_per_warp = warp_size / 2;
1518
2180
  #pragma unroll
1519
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1520
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1521
-
1522
- if (need_check) {
1523
- i = min(i, i_max);
1524
- }
2181
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2182
+ #if defined(AMD_MFMA_AVAILABLE)
2183
+ // Need if on AMD instead of % because warp_size == 64
2184
+ // This causes double work and throughput loss (MI300X)
2185
+ // H100 loses about 100 t/s with 'if' condition over '%'
2186
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
2187
+ if (i < mmq_y) {
2188
+ #else
2189
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
2190
+ {
2191
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2192
+ if (need_check) {
2193
+ i = min(i, i_max);
2194
+ }
1525
2195
 
1526
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
2196
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1527
2197
 
1528
- const int * scales = (const int *) bxi->scales;
1529
- const int ksc = threadIdx.x % (WARP_SIZE/16);
2198
+ const int * scales = (const int *) bxi->scales;
2199
+ const int ksc = threadIdx.x % 2;
1530
2200
 
1531
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1532
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
2201
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
2202
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1533
2203
 
1534
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1535
- const uint8_t * m8 = (const uint8_t *) &m32;
2204
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
2205
+ const uint8_t * m8 = (const uint8_t *) &m32;
1536
2206
 
1537
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
2207
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1538
2208
 
1539
2209
  #pragma unroll
1540
- for (int l = 0; l < int(sizeof(int)); ++l) {
1541
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2210
+ for (int l = 0; l < int(sizeof(int)); ++l) {
2211
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
2212
+ }
1542
2213
  }
1543
2214
  }
1544
-
1545
2215
  #else
1546
-
1547
2216
  #pragma unroll
1548
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
1549
- int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
2217
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2218
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1550
2219
 
1551
2220
  if (need_check) {
1552
2221
  i = min(i, i_max);
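The MMA branch above stores one half2 per scale group built as bxi->dm * make_half2(1.0f, -1.0f) times make_half2(sc8[l], m8[l]), i.e. the pair (d*sc, -dmin*m). The reason is the usual factoring of a q4_K/q5_K dot product: a weight is d*sc*q - dmin*m, so the product against int8 values splits into an integer dot plus a correction that only needs the sum of the int8 operand. A scalar check with made-up numbers:

// Verifies the factoring that motivates the (d*sc, -dmin*m) tile entries.
#include <cstdio>
#include <cmath>

int main() {
    const float d = 0.05f, dmin = 0.01f;                // super-block scales (example values)
    const float sc = 37.0f, m = 12.0f;                  // one sub-scale / sub-min
    const int   q[8] = {0, 3, 7, 15, 8, 2, 11, 5};      // 4-bit weights
    const int   y[8] = {-7, 4, 0, 25, -12, 9, 1, -3};   // int8 activations

    float direct = 0.0f;
    int   sumqy  = 0, sumy = 0;
    for (int k = 0; k < 8; ++k) {
        direct += (d*sc*q[k] - dmin*m) * y[k];           // dequantize, then dot
        sumqy  += q[k]*y[k];                             // integer part
        sumy   += y[k];                                  // correction only needs sum(y)
    }
    const float factored = (d*sc)*sumqy + (-dmin*m)*sumy;  // the two stored tile entries
    printf("direct=%f factored=%f diff=%g\n", direct, factored, std::fabs(direct - factored));
    return 0;
}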
@@ -1557,9 +2226,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1557
2226
  x_dm[i] = bxi->dm;
1558
2227
  }
1559
2228
 
2229
+ constexpr int rows_per_warp = warp_size / 4;
1560
2230
  #pragma unroll
1561
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1562
- int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
2231
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2232
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1563
2233
 
1564
2234
  if (need_check) {
1565
2235
  i = min(i, i_max);
@@ -1569,17 +2239,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1569
2239
 
1570
2240
  const int * scales = (const int *) bxi->scales;
1571
2241
 
1572
- const int ksc = threadIdx.x % (WARP_SIZE/8);
2242
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1573
2243
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1574
2244
 
1575
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
2245
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1576
2246
  }
1577
- #endif // NEW_MMA_AVAILABLE
2247
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1578
2248
  }
1579
2249
 
1580
- template <int mmq_x, int mmq_y, int nwarps>
2250
+ template <int mmq_x, int mmq_y>
1581
2251
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1582
2252
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2253
+ constexpr int nwarps = mmq_get_nwarps_device();
2254
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1583
2255
 
1584
2256
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1585
2257
  const int * x_qs = (const int *) x;
@@ -1589,7 +2261,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1589
2261
  const half2 * y_ds = (const half2 *) y;
1590
2262
 
1591
2263
  // #pragma unroll
1592
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
2264
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1593
2265
  const int k0 = k00 + k01;
1594
2266
 
1595
2267
  #pragma unroll
@@ -1597,36 +2269,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1597
2269
  const int j = j0 + threadIdx.y;
1598
2270
 
1599
2271
  #pragma unroll
1600
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2272
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1601
2273
  const int i = i0 + threadIdx.x;
1602
2274
 
1603
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
2275
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
1604
2276
 
1605
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1606
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2277
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
2278
+ &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1607
2279
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1608
2280
  }
1609
2281
  }
1610
2282
  }
1611
2283
  }
1612
2284
 
1613
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
2285
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1614
2286
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2287
+ constexpr int nwarps = mmq_get_nwarps_device();
2288
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1615
2289
 
1616
- #ifdef NEW_MMA_AVAILABLE
2290
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1617
2291
  int * x_qs = (int *) x_tile;
1618
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1619
- int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
2292
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2293
+ int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
1620
2294
  #else
1621
2295
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1622
2296
  int * x_qs = (int *) x_tile;
1623
2297
  float * x_df = (float *) (x_qs + txs.qs);
1624
2298
  int * x_sc = (int *) (x_df + txs.dm);
1625
- #endif // NEW_MMA_AVAILABLE
2299
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2300
+
2301
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2302
+ constexpr int nrows = warp_size / threads_per_row;
2303
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1626
2304
 
1627
2305
  #pragma unroll
1628
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1629
- int i = i0 + threadIdx.y;
2306
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2307
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1630
2308
 
1631
2309
  if (need_check) {
1632
2310
  i = min(i, i_max);
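The q6_K loader in the next hunk rebuilds each 6-bit weight from a 4-bit low part (ql) and a 2-bit high part (qh), then recentres it by 32, handling four bytes at once with __vsubss4(ql0 | qh0, 0x20202020). The per-byte equivalent, for reference:

// Scalar picture of the q6_K value reconstruction: (low4 | high2<<4) - 32.
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t low4 [4] = {0x0, 0xF, 0x7, 0xA};   // example low nibbles
    const uint8_t high2[4] = {0x0, 0x3, 0x1, 0x2};   // example 2-bit high parts
    for (int k = 0; k < 4; ++k) {
        const uint8_t q6 = (uint8_t)(low4[k] | (high2[k] << 4));  // 0..63
        const int8_t  w  = (int8_t)(q6 - 32);                     // recentred weight, -32..31
        printf("byte %d: ql=%2u qh=%u -> q6=%2u -> weight %3d\n", k, low4[k], high2[k], q6, w);
    }
    return 0;
}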
@@ -1634,67 +2312,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1634
2312
 
1635
2313
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1636
2314
 
1637
- const int ql = get_int_b2(bxi->ql, threadIdx.x);
2315
+ const int ql = get_int_b2(bxi->ql, txi);
1638
2316
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1639
2317
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1640
2318
 
1641
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
1642
- const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
1643
- const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
2319
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
2320
+ const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
2321
+ const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
1644
2322
 
1645
- const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1646
- const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
2323
+ const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2324
+ const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
1647
2325
 
1648
- #ifdef NEW_MMA_AVAILABLE
2326
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1649
2327
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1650
2328
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1651
2329
  #else
1652
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1653
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1654
- #endif // NEW_MMA_AVAILABLE
2330
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2331
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2332
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1655
2333
  }
1656
2334
 
1657
- const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
1658
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
1659
-
1660
2335
  #pragma unroll
1661
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
1662
- int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
2336
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2337
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1663
2338
 
1664
2339
  if (need_check) {
1665
2340
  i = min(i, i_max);
1666
2341
  }
1667
2342
 
1668
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
2343
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1669
2344
 
1670
- #ifdef NEW_MMA_AVAILABLE
1671
- x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
2345
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2346
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
1672
2347
  #else
1673
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
1674
- #endif // NEW_MMA_AVAILABLE
2348
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2349
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1675
2350
  }
1676
2351
 
2352
+ constexpr int rows_per_warp = warp_size / 4;
1677
2353
  #pragma unroll
1678
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1679
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
2354
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2355
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1680
2356
 
1681
2357
  if (need_check) {
1682
2358
  i = min(i, i_max);
1683
2359
  }
1684
2360
 
1685
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
2361
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
1686
2362
 
1687
- #ifdef NEW_MMA_AVAILABLE
1688
- x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
2363
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2364
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
1689
2365
  #else
1690
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1691
- #endif // NEW_MMA_AVAILABLE
2366
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2367
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1692
2368
  }
1693
2369
  }
1694
2370
 
1695
- template <int mmq_x, int mmq_y, int nwarps>
2371
+ template <int mmq_x, int mmq_y>
1696
2372
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1697
2373
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2374
+ constexpr int nwarps = mmq_get_nwarps_device();
2375
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1698
2376
 
1699
2377
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1700
2378
  const int * x_qs = (const int *) x;
@@ -1704,7 +2382,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1704
2382
  const float * y_df = (const float *) y;
1705
2383
 
1706
2384
  // #pragma unroll
1707
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2385
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
1708
2386
  const int k0 = k00 + k01;
1709
2387
 
1710
2388
  #pragma unroll
@@ -1712,23 +2390,126 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1712
2390
  const int j = j0 + threadIdx.y;
1713
2391
 
1714
2392
  #pragma unroll
1715
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2393
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1716
2394
  const int i = i0 + threadIdx.x;
1717
2395
 
1718
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
2396
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
1719
2397
 
1720
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1721
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
1722
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2398
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
2399
+ &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2400
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1723
2401
  }
1724
2402
  }
1725
2403
  }
1726
2404
  }
1727
2405
 
1728
- template <int mmq_x, int mmq_y, int nwarps>
2406
+ template <int mmq_x, int mmq_y>
1729
2407
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1730
2408
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1731
- #ifdef NEW_MMA_AVAILABLE
2409
+ #if defined(AMD_MFMA_AVAILABLE)
2410
+ constexpr data_layout input_layout = get_input_data_layout();
2411
+ typedef tile<16, 8, int, input_layout> tile_A;
2412
+ typedef tile<16, 8, int, input_layout> tile_B;
2413
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414
+ typedef tile<64, 2, int, input_layout> tile_load;
2415
+
2416
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2417
+ constexpr int rows_per_warp = granularity;
2418
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2419
+
2420
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2421
+
2422
+ const int * x_qs = (const int *) x;
2423
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2424
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2425
+ const int * y_qs = (const int *) y + 4;
2426
+ const float * y_df = (const float *) y;
2427
+
2428
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2429
+
2430
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2431
+ const int k0 = k00 + k01;
2432
+
2433
+ tile_A A[ntx];
2434
+ #pragma unroll
2435
+ for (int n = 0; n < ntx; ++n) {
2436
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2437
+ }
2438
+
2439
+ #pragma unroll
2440
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2441
+ tile_B B[1];
2442
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2443
+
2444
+ const int j = j0 + tile_C::get_j(0);
2445
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2446
+
2447
+ #pragma unroll
2448
+ for (int n = 0; n < ntx; ++n) {
2449
+ tile_C C;
2450
+ mma(C, A[n], B[0]);
2451
+
2452
+ #pragma unroll
2453
+ for (int l = 0; l < tile_C::ne; ++l) {
2454
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457
+ }
2458
+ }
2459
+ }
2460
+ }
2461
+ #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
2462
+ constexpr data_layout input_layout = get_input_data_layout();
2463
+ typedef tile<16, 4, int, input_layout> tile_A;
2464
+ typedef tile<16, 4, int, input_layout> tile_B;
2465
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2466
+
2467
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2468
+ constexpr int rows_per_warp = granularity;
2469
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2470
+
2471
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2472
+
2473
+ const int * x_qs = (const int *) x;
2474
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2475
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2476
+ const int * y_qs = (const int *) y + 4;
2477
+ const float * y_df = (const float *) y;
2478
+
2479
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2480
+
2481
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2482
+ const int k0 = k00 + k01;
2483
+
2484
+ tile_A A[ntx];
2485
+ #pragma unroll
2486
+ for (int n = 0; n < ntx; ++n) {
2487
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488
+ }
2489
+
2490
+ #pragma unroll
2491
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492
+ tile_B B;
2493
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
+
2495
+ const int j = j0 + tile_C::get_j(0);
2496
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2497
+
2498
+ #pragma unroll
2499
+ for (int n = 0; n < ntx; ++n) {
2500
+ tile_C C;
2501
+ mma(C, A[n], B);
2502
+
2503
+ #pragma unroll
2504
+ for (int l = 0; l < tile_C::ne; ++l) {
2505
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2506
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2507
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2508
+ }
2509
+ }
2510
+ }
2511
+ }
2512
+ #elif defined(TURING_MMA_AVAILABLE)
1732
2513
 
1733
2514
  typedef tile<16, 4, int> tile_A;
1734
2515
  typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2519,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1738
2519
  constexpr int rows_per_warp = 2 * granularity;
1739
2520
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1740
2521
 
1741
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
2522
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1742
2523
 
1743
2524
  const int * x_qs = (const int *) x;
1744
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1745
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
2525
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2526
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
1746
2527
  const int * y_qs = (const int *) y + 4;
1747
2528
  const float * y_df = (const float *) y;
1748
2529
 
@@ -1755,7 +2536,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1755
2536
  #pragma unroll
1756
2537
  for (int n = 0; n < ntx; ++n) {
1757
2538
  #pragma unroll
1758
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2539
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1759
2540
  const int k0 = k00 + k01;
1760
2541
 
1761
2542
  load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2544,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1763
2544
  }
1764
2545
 
1765
2546
  #pragma unroll
1766
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
2547
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
1767
2548
  const int k0 = k00 + k01;
1768
2549
 
1769
2550
  #pragma unroll
@@ -1793,7 +2574,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1793
2574
  float tmp[ntx][tile_C::ne] = {{0.0f}};
1794
2575
 
1795
2576
  #pragma unroll
1796
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2577
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1797
2578
  tile_B B[2];
1798
2579
  float dB[tile_C::ne/2];
1799
2580
 
@@ -1830,29 +2611,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1830
2611
  }
1831
2612
  }
1832
2613
  #else
1833
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
2614
+ GGML_UNUSED_VARS(x, y, sum, k00);
1834
2615
  NO_DEVICE_CODE;
1835
- #endif // NEW_MMA_AVAILABLE
2616
+ #endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE
1836
2617
  }
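In the MMA variants above, warps are assigned work through granularity, rows_per_warp and ntx: rows_per_warp equals the granularity on the AMD paths and twice it on the Turing path, and ntx = rows_per_warp / tile_C::I decides how many column slices of y a group of warps splits between them. A host-side sketch of the resulting warp-to-tile mapping; the granularity and warp count below are assumed example values (the kernel takes them from mmq_get_granularity_device(mmq_x) and the launch configuration):

// Illustrates i0 = (warp/ntx)*rows_per_warp and the (warp%ntx)*tile_C::J column offset.
#include <cstdio>

int main() {
    const int tile_I = 16, tile_J = 16;      // accumulator tile shape as in the AMD paths
    const int granularity   = 32;            // example value only
    const int rows_per_warp = granularity;
    const int ntx    = rows_per_warp / tile_I;
    const int nwarps = 8;                    // example warp count
    for (int w = 0; w < nwarps; ++w) {
        const int i0    = (w / ntx) * rows_per_warp;   // first x row this warp covers
        const int j_off = (w % ntx) * tile_J;          // y column slice it is paired with
        printf("warp %d -> x rows [%3d, %3d), y columns starting at %2d\n",
               w, i0, i0 + rows_per_warp, j_off);
    }
    return 0;
}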
1837
2618
 
1838
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
2619
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1839
2620
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2621
+ constexpr int nwarps = mmq_get_nwarps_device();
2622
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1840
2623
 
1841
- #ifdef NEW_MMA_AVAILABLE
2624
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1842
2625
  int * x_qs = (int *) x_tile;
1843
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2626
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1844
2627
  #else
1845
2628
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
1846
2629
  int * x_qs = (int *) x_tile;
1847
2630
  float * x_df = (float *) (x_qs + txs.qs);
1848
- #endif // NEW_MMA_AVAILABLE
2631
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1849
2632
 
1850
- const int kbx = threadIdx.x / QI4_NL;
1851
- const int kqsx = threadIdx.x % QI4_NL;
2633
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2634
+ constexpr int nrows = warp_size / threads_per_row;
2635
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2636
+ const int kbx = txi / QI4_NL;
2637
+ const int kqsx = txi % QI4_NL;
1852
2638
 
1853
2639
  #pragma unroll
1854
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1855
- int i = i0 + threadIdx.y;
2640
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2641
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1856
2642
 
1857
2643
  if (need_check) {
1858
2644
  i = min(i, i_max);
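The iq4_nl loader in the next hunk now passes the codebook explicitly: get_int_from_table_16(aux_q4, kvalues_iq4nl) expands eight packed 4-bit indices into two ints whose bytes are int8 codebook entries, one int for the low nibbles and one for the high nibbles. A rough host-side equivalent; the table below is a made-up placeholder (the real kernel uses kvalues_iq4nl from ggml-common.h) and the exact nibble-to-byte ordering of the real helper may differ:

// Expands eight 4-bit indices into two words of int8 codebook values.
#include <cstdint>
#include <cstdio>

static void expand_nibbles(uint32_t packed, const int8_t table[16], uint32_t &lo, uint32_t &hi) {
    lo = hi = 0;
    for (int k = 0; k < 4; ++k) {
        const uint8_t byte = (packed >> 8*k) & 0xFF;
        lo |= (uint32_t)(uint8_t)table[byte & 0x0F] << 8*k;  // low nibble of each byte
        hi |= (uint32_t)(uint8_t)table[byte >> 4]   << 8*k;  // high nibble of each byte
    }
}

int main() {
    const int8_t toy_table[16] = {-8,-7,-6,-5,-4,-3,-2,-1, 0, 1, 2, 3, 4, 5, 6, 7}; // placeholder
    uint32_t lo, hi;
    expand_nibbles(0x0F3A5C81u, toy_table, lo, hi);
    printf("lo bytes = %08x, hi bytes = %08x\n", lo, hi);
    return 0;
}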
@@ -1861,23 +2647,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1861
2647
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
1862
2648
 
1863
2649
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
1864
- const int2 v = get_int_from_table_16(aux_q4);
1865
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
1866
- #ifdef NEW_MMA_AVAILABLE
1867
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
1868
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2650
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2651
+ const int k0 = kbx * (2 * QI4_NL) + kqsx;
2652
+
2653
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2654
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2655
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
1869
2656
  #else
1870
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
1871
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
1872
- #endif // NEW_MMA_AVAILABLE
2657
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2658
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2659
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1873
2660
  }
1874
2661
 
1875
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
2662
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
2663
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
1876
2664
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
1877
2665
 
1878
2666
  #pragma unroll
1879
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
1880
- int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
2667
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2668
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
1881
2669
 
1882
2670
  if (need_check) {
1883
2671
  i = min(i, i_max);
@@ -1885,31 +2673,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1885
2673
 
1886
2674
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
1887
2675
 
1888
- #ifdef NEW_MMA_AVAILABLE
1889
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2676
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2677
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
1890
2678
  #else
1891
- x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
1892
- #endif // NEW_MMA_AVAILABLE
2679
+ x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2680
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1893
2681
  }
1894
2682
  }
1895
2683
 
1896
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
2684
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1897
2685
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2686
+ constexpr int nwarps = mmq_get_nwarps_device();
2687
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1898
2688
 
1899
- #ifdef NEW_MMA_AVAILABLE
2689
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1900
2690
  int * x_qs = (int *) x_tile;
1901
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2691
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1902
2692
  #else
1903
2693
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1904
2694
  int * x_qs = (int *) x_tile;
1905
2695
  float * x_df = (float *) (x_qs + txs.qs);
1906
- #endif // NEW_MMA_AVAILABLE
2696
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1907
2697
 
1908
- const int kqsx = threadIdx.x % (QI2_XXS/2);
2698
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2699
+ constexpr int nrows = warp_size / threads_per_row;
2700
+ const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1909
2701
 
1910
2702
  #pragma unroll
1911
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
1912
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
2703
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2704
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1913
2705
 
1914
2706
  if (need_check) {
1915
2707
  i = min(i, i_max);
@@ -1932,42 +2724,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1932
2724
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1933
2725
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1934
2726
 
1935
- #ifdef NEW_MMA_AVAILABLE
2727
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1936
2728
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1937
2729
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1938
2730
  #else
1939
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1940
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1941
- #endif // NEW_MMA_AVAILABLE
2731
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2732
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2733
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1942
2734
  }
1943
2735
 
1944
2736
  const int ls = aux32 >> 28;
1945
2737
  const float d = bxi->d;
1946
- #ifdef NEW_MMA_AVAILABLE
1947
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2738
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2739
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1948
2740
  #else
1949
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1950
- #endif // NEW_MMA_AVAILABLE
2741
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2742
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1951
2743
  }
1952
2744
  }
1953
2745
 
1954
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
2746
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1955
2747
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2748
+ constexpr int nwarps = mmq_get_nwarps_device();
2749
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1956
2750
 
1957
- #ifdef NEW_MMA_AVAILABLE
2751
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1958
2752
  int * x_qs = (int *) x_tile;
1959
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2753
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1960
2754
  #else
1961
2755
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1962
2756
  int * x_qs = (int *) x_tile;
1963
2757
  float * x_df = (float *) (x_qs + txs.qs);
1964
- #endif // NEW_MMA_AVAILABLE
2758
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1965
2759
 
1966
- const int kqsx = threadIdx.x % (QI2_XS/2);
2760
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2761
+ constexpr int nrows = warp_size / threads_per_row;
2762
+ const int kqsx = threadIdx.x % threads_per_row;
1967
2763
 
1968
2764
  #pragma unroll
1969
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
1970
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
2765
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2766
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1971
2767
 
1972
2768
  if (need_check) {
1973
2769
  i = min(i, i_max);
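The iq2/iq3 loaders in these hunks apply per-value signs with the XOR-then-subtract idiom, __vsub4(grid ^ signs, signs): the sign mask produced by __vcmpne4 holds 0xFF in bytes that must be negated and 0x00 elsewhere, so (g ^ mask) - mask flips exactly those bytes in two's complement. Single-byte equivalent:

// Per-byte demonstration of conditional negation via XOR and wrapping subtraction.
#include <cstdint>
#include <cstdio>

int main() {
    const int8_t  grid[4] = {25, -11, 8, 127};
    const uint8_t mask[4] = {0xFF, 0x00, 0xFF, 0x00};   // negate bytes 0 and 2
    for (int k = 0; k < 4; ++k) {
        const uint8_t g = (uint8_t)grid[k];
        const int8_t  r = (int8_t)(uint8_t)((g ^ mask[k]) - mask[k]);  // wrapping byte math
        printf("%4d with mask 0x%02X -> %4d\n", grid[k], mask[k], r);
    }
    return 0;
}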
@@ -1986,44 +2782,47 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1986
2782
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1987
2783
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1988
2784
 
1989
- #ifdef NEW_MMA_AVAILABLE
2785
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1990
2786
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1991
2787
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1992
2788
  #else
1993
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1994
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1995
- #endif // NEW_MMA_AVAILABLE
2789
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2790
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2791
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1996
2792
  }
1997
2793
 
1998
2794
  const int ls = bxi->scales[kqsx];
1999
2795
  const float d = bxi->d;
2000
- #ifdef NEW_MMA_AVAILABLE
2001
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2002
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2796
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2797
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2798
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2003
2799
  #else
2004
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2005
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2006
- #endif // NEW_MMA_AVAILABLE
2800
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2801
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2802
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2007
2803
  }
2008
2804
  }
2009
2805
 
2010
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2806
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2011
2807
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2808
+ constexpr int nwarps = mmq_get_nwarps_device();
2809
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2012
2810
 
2013
- #ifdef NEW_MMA_AVAILABLE
2811
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2014
2812
  int * x_qs = (int *) x_tile;
2015
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2813
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2016
2814
  #else
2017
2815
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2018
2816
  int * x_qs = (int *) x_tile;
2019
2817
  float * x_df = (float *) (x_qs + txs.qs);
2020
- #endif // NEW_MMA_AVAILABLE
2021
-
2022
- const int kqsx = threadIdx.x % (QI2_S/2);
2818
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2819
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2820
+ constexpr int nrows = warp_size / threads_per_row;
2821
+ const int kqsx = threadIdx.x % threads_per_row;
2023
2822
 
2024
2823
  #pragma unroll
2025
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
2026
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
2824
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2825
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2027
2826
 
2028
2827
  if (need_check) {
2029
2828
  i = min(i, i_max);
@@ -2049,44 +2848,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2049
2848
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2050
2849
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2051
2850
 
2052
- #ifdef NEW_MMA_AVAILABLE
2851
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2053
2852
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2054
2853
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2055
2854
  #else
2056
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2057
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2058
- #endif // NEW_MMA_AVAILABLE
2855
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2856
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2857
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2059
2858
  }
2060
2859
 
2061
2860
  const int ls = bxi->scales[kqsx];
2062
2861
  const float d = bxi->d;
2063
- #ifdef NEW_MMA_AVAILABLE
2064
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2065
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2862
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2863
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2864
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2066
2865
  #else
2067
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2068
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2069
- #endif // NEW_MMA_AVAILABLE
2866
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2867
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2868
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2070
2869
  }
2071
2870
  }
2072
2871
 
2073
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2872
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2074
2873
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2874
+ constexpr int nwarps = mmq_get_nwarps_device();
2875
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2075
2876
 
2076
- #ifdef NEW_MMA_AVAILABLE
2877
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2077
2878
  int * x_qs = (int *) x_tile;
2078
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2879
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2079
2880
  #else
2080
2881
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2081
2882
  int * x_qs = (int *) x_tile;
2082
2883
  float * x_df = (float *) (x_qs + txs.qs);
2083
- #endif // NEW_MMA_AVAILABLE
2884
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2084
2885
 
2085
- const int kqsx = threadIdx.x % (QI3_XXS/2);
2886
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2887
+ constexpr int nrows = warp_size / threads_per_row;
2888
+ const int kqsx = threadIdx.x % threads_per_row;
2086
2889
 
2087
2890
  #pragma unroll
2088
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
2089
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
2891
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2892
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2090
2893
 
2091
2894
  if (need_check) {
2092
2895
  i = min(i, i_max);
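Throughout the dp4a branches the shared tiles are indexed with a row stride of MMQ_TILE_NE_K + 1 or 2*MMQ_TILE_NE_K + 1 rather than the plain tile width; the extra element per row is, presumably, the usual padding that skews consecutive rows across shared-memory banks so column-wise accesses do not all hit the same bank. Illustrative index math with an example tile width:

// Shows how a +1 row stride spreads same-column accesses over different banks.
#include <cstdio>

int main() {
    const int width = 32, banks = 32;   // example tile width; 32 four-byte banks
    for (int row = 0; row < 4; ++row) {
        const int col = 0;                         // lanes reading the same column
        const int unpadded = row*width       + col;
        const int padded   = row*(width + 1) + col;
        printf("row %d, col 0: bank %2d without padding, bank %2d with padding\n",
               row, unpadded % banks, padded % banks);
    }
    return 0;
}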
@@ -2107,42 +2910,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2107
2910
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2108
2911
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2109
2912
 
2110
- #ifdef NEW_MMA_AVAILABLE
2913
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2111
2914
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2112
2915
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2113
2916
  #else
2114
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2115
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2116
- #endif // NEW_MMA_AVAILABLE
2917
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2918
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2919
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2117
2920
  }
2118
2921
 
2119
2922
  const int ls = aux32 >> 28;
2120
2923
  const float d = bxi->d;
2121
- #ifdef NEW_MMA_AVAILABLE
2122
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2924
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2925
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2123
2926
  #else
2124
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2125
- #endif // NEW_MMA_AVAILABLE
2927
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2928
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2126
2929
  }
2127
2930
  }
2128
2931
 
2129
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2932
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2130
2933
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2934
+ constexpr int nwarps = mmq_get_nwarps_device();
2935
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2131
2936
 
2132
- #ifdef NEW_MMA_AVAILABLE
2937
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2133
2938
  int * x_qs = (int *) x_tile;
2134
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2939
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2135
2940
  #else
2136
2941
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2137
2942
  int * x_qs = (int *) x_tile;
2138
2943
  float * x_df = (float *) (x_qs + txs.qs);
2139
- #endif // NEW_MMA_AVAILABLE
2944
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2140
2945
 
2141
- const int kqsx = threadIdx.x % (QI3_S/2);
2946
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2947
+ constexpr int nrows = warp_size / threads_per_row;
2948
+ const int kqsx = threadIdx.x % threads_per_row;
2142
2949
 
2143
2950
  #pragma unroll
2144
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
2145
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
2951
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2952
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2146
2953
 
2147
2954
  if (need_check) {
2148
2955
  i = min(i, i_max);
@@ -2170,42 +2977,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2170
2977
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2171
2978
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2172
2979
 
2173
- #ifdef NEW_MMA_AVAILABLE
2980
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2174
2981
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2175
2982
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2176
2983
  #else
2177
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2178
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2179
- #endif // NEW_MMA_AVAILABLE
2984
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2985
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2986
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2180
2987
  }
2181
2988
 
2182
2989
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2183
2990
  const float d = bxi->d;
2184
- #ifdef NEW_MMA_AVAILABLE
2185
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2991
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2992
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2186
2993
  #else
2187
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2188
- #endif // NEW_MMA_AVAILABLE
2994
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2995
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2189
2996
  }
2190
2997
  }
2191
2998
 
2192
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2999
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2193
3000
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
3001
+ constexpr int nwarps = mmq_get_nwarps_device();
3002
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2194
3003
 
2195
- #ifdef NEW_MMA_AVAILABLE
3004
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2196
3005
  int * x_qs = (int *) x_tile;
2197
- half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
3006
+ half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2198
3007
  #else
2199
3008
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2200
3009
  int * x_qs = (int *) x_tile;
2201
3010
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2202
- #endif // NEW_MMA_AVAILABLE
3011
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2203
3012
 
2204
- const int kqsx = threadIdx.x % QI1_S;
3013
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
3014
+ constexpr int nrows = warp_size / threads_per_row;
3015
+ const int kqsx = threadIdx.x % threads_per_row;
2205
3016
 
2206
3017
  #pragma unroll
2207
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
2208
- int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
3018
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
3019
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2209
3020
 
2210
3021
  if (need_check) {
2211
3022
  i = min(i, i_max);
@@ -2225,66 +3036,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2225
3036
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2226
3037
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2227
3038
 
2228
- #ifdef NEW_MMA_AVAILABLE
3039
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2229
3040
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2230
3041
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2231
3042
  #else
2232
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2233
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2234
- #endif // NEW_MMA_AVAILABLE
3043
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
3044
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
3045
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2235
3046
  }
2236
3047
 
2237
3048
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2238
3049
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2239
3050
 
2240
- #ifdef NEW_MMA_AVAILABLE
2241
- x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
3051
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3052
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2242
3053
  #else
2243
- x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2244
- #endif // NEW_MMA_AVAILABLE
3054
+ x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
3055
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2245
3056
  }
2246
3057
  }
2247
3058
 
2248
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
3059
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2249
3060
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
3061
+ constexpr int nwarps = mmq_get_nwarps_device();
3062
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2250
3063
 
2251
- #ifdef NEW_MMA_AVAILABLE
3064
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2252
3065
  int * x_qs = (int *) x_tile;
2253
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
3066
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2254
3067
  #else
2255
3068
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2256
3069
  int * x_qs = (int *) x_tile;
2257
3070
  float * x_df = (float *) (x_qs + txs.qs);
2258
- #endif // NEW_MMA_AVAILABLE
3071
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2259
3072
 
2260
- const int kbx = 0; // threadIdx.x / QI4_XS
2261
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
3073
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
3074
+ constexpr int nrows = warp_size / threads_per_row;
3075
+ const int kqsx = threadIdx.x % threads_per_row;
2262
3076
 
2263
3077
  #pragma unroll
2264
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2265
- int i = i0 + threadIdx.y;
3078
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
3079
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2266
3080
 
2267
3081
  if (need_check) {
2268
3082
  i = min(i, i_max);
2269
3083
  }
2270
3084
 
2271
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
3085
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2272
3086
 
2273
3087
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2274
- const int2 v = get_int_from_table_16(aux_q4);
2275
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2276
- #ifdef NEW_MMA_AVAILABLE
3088
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
3089
+ const int k0 = 8 * (kqsx / 4) + kqsx % 4;
3090
+
3091
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2277
3092
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2278
3093
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2279
3094
  #else
2280
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2281
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2282
- #endif // NEW_MMA_AVAILABLE
3095
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
3096
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
3097
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2283
3098
  }
2284
3099
 
3100
+ constexpr int rows_per_warp = warp_size / 8;
2285
3101
  #pragma unroll
2286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2287
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
3102
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
3103
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
2288
3104
 
2289
3105
  if (need_check) {
2290
3106
  i = min(i, i_max);
@@ -2297,18 +3113,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2297
3113
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2298
3114
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2299
3115
 
2300
- #ifdef NEW_MMA_AVAILABLE
2301
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
3116
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3117
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2302
3118
  #else
2303
- x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2304
- #endif // NEW_MMA_AVAILABLE
3119
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
3120
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2305
3121
  }
2306
3122
  }
2307
3123
 
2308
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
3124
+ template<int mmq_x, int mmq_y, bool need_check>
2309
3125
  static __device__ __forceinline__ void mmq_write_back_dp4a(
2310
3126
  const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
2311
3127
  const int stride, const int i_max, const int j_max) {
3128
+ constexpr int nwarps = mmq_get_nwarps_device();
3129
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3130
+
2312
3131
  #pragma unroll
2313
3132
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2314
3133
  const int j = j0 + threadIdx.y;
@@ -2318,32 +3137,42 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
2318
3137
  }
2319
3138
 
2320
3139
  #pragma unroll
2321
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3140
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2322
3141
  const int i = i0 + threadIdx.x;
2323
3142
 
2324
3143
  if (need_check && i > i_max) {
2325
3144
  continue;
2326
3145
  }
2327
3146
 
2328
- dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
3147
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2329
3148
  }
2330
3149
  }
2331
3150
  }
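A side note on the indexing used in mmq_write_back_dp4a above: each thread owns mmq_x*mmq_y/(nwarps*warp_size) partial results, and (j0/nwarps)*(mmq_y/warp_size) + i0/warp_size enumerates them densely in loop order. A tiny host-side check of that mapping, with illustrative sizes only (not taken from this file):

    #include <cassert>

    int main() {
        const int mmq_x = 64, mmq_y = 128;    // illustrative tile shape
        const int nwarps = 8, warp_size = 32; // illustrative launch shape

        const int per_thread = mmq_x*mmq_y / (nwarps*warp_size); // 32 partial sums here
        int seen = 0;
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int idx = (j0/nwarps)*(mmq_y/warp_size) + i0/warp_size;
                assert(idx == seen); // indices are dense and follow loop order
                ++seen;
            }
        }
        assert(seen == per_thread);
        return 0;
    }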
2332
3151
 
2333
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
3152
+ template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
2334
3153
  static __device__ __forceinline__ void mmq_write_back_mma(
2335
3154
  const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
2336
3155
  const int stride, const int i_max, const int j_max) {
2337
- typedef tile<16, 8, int> tile_C;
2338
3156
 
2339
3157
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
3158
+ constexpr int nwarps = mmq_get_nwarps_device();
3159
+
3160
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3161
+ constexpr int tileC_IJ = mmq_get_granularity_device(0);
3162
+ typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
3163
+ constexpr int rows_per_warp = granularity;
3164
+ #else
3165
+ typedef tile<16, 8, int> tile_C;
2340
3166
  constexpr int rows_per_warp = 2 * granularity;
3167
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2341
3168
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2342
3169
 
2343
3170
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2344
- #ifdef NEW_MMA_AVAILABLE
3171
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2345
3172
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2346
- #endif // NEW_MMA_AVAILABLE
3173
+ #else
3174
+ GGML_UNUSED(nwarps);
3175
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2347
3176
 
2348
3177
  #pragma unroll
2349
3178
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2371,188 +3200,212 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2371
3200
 
2372
3201
  // -------------------------------------------------------------------------------------------------------------------------------------
2373
3202
 
2374
- template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
3203
+ template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
2375
3204
  struct mmq_type_traits;
2376
3205
 
2377
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2378
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
3206
+ template <int mmq_x, int mmq_y, bool need_check>
3207
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
2379
3208
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2380
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
2381
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
2382
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3209
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
3210
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
3211
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
2383
3212
  };
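The per-type specializations that follow all use the same compile-time traits-dispatch pattern; a stripped-down, self-contained sketch of the idea (the tag and function names here are made up, not from this file) is:

    using fn_t = int (*)(int);

    static int load_a(int x) { return x + 1; }
    static int load_b(int x) { return x * 2; }

    enum class fmt { a, b };

    template <fmt F> struct traits; // primary template; only specializations are defined

    template <> struct traits<fmt::a> { static constexpr fn_t load = load_a; };
    template <> struct traits<fmt::b> { static constexpr fn_t load = load_b; };

    template <fmt F>
    int run(int x) {
        constexpr fn_t load = traits<F>::load; // resolved at compile time, like load_tiles above
        return load(x);
    }

    int main() {
        return run<fmt::a>(1) + run<fmt::b>(2); // 2 + 4
    }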
2384
3213
 
2385
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2386
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
3214
+ template <int mmq_x, int mmq_y, bool need_check>
3215
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
2387
3216
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2388
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
2389
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2390
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3217
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
3218
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3219
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
2391
3220
  };
2392
3221
 
2393
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2394
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
3222
+ template <int mmq_x, int mmq_y, bool need_check>
3223
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
2395
3224
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2396
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2397
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2398
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3225
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
3226
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3227
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2399
3228
  };
2400
3229
 
2401
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2402
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
3230
+ template <int mmq_x, int mmq_y, bool need_check>
3231
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
2403
3232
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2404
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
2405
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2406
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3233
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
3234
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3235
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2407
3236
  };
2408
3237
 
2409
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2410
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
3238
+ template <int mmq_x, int mmq_y, bool need_check>
3239
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2411
3240
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2412
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
2413
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2414
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3241
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
3242
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3243
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3244
+ };
3245
+
3246
+ template <int mmq_x, int mmq_y, bool need_check>
3247
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3248
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
3249
+ #ifdef BLACKWELL_MMA_AVAILABLE
3250
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3251
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
3252
+ #else
3253
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
3254
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3255
+ #endif // BLACKWELL_MMA_AVAILABLE
3256
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2415
3257
  };
2416
3258
 
2417
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2418
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
3259
+ template <int mmq_x, int mmq_y, bool need_check>
3260
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2419
3261
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
2420
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
2421
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2422
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3262
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
3263
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
3264
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
2423
3265
  };
2424
3266
 
2425
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2426
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
3267
+ template <int mmq_x, int mmq_y, bool need_check>
3268
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
2427
3269
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2428
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
2429
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2430
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3270
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
3271
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3272
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
2431
3273
  };
2432
3274
 
2433
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2434
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
3275
+ template <int mmq_x, int mmq_y, bool need_check>
3276
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
2435
3277
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
2436
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
2437
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2438
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3278
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
3279
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3280
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
2439
3281
  };
2440
3282
 
2441
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2442
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
3283
+ template <int mmq_x, int mmq_y, bool need_check>
3284
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
2443
3285
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
2444
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
2445
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2446
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3286
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
3287
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3288
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
2447
3289
  };
2448
3290
 
2449
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2450
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
3291
+ template <int mmq_x, int mmq_y, bool need_check>
3292
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
2451
3293
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
2452
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
2453
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2454
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3294
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
3295
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
3296
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
2455
3297
  };
2456
3298
 
2457
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2458
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
3299
+ template <int mmq_x, int mmq_y, bool need_check>
3300
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
2459
3301
  static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
2460
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
2461
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2462
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3302
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
3303
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3304
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2463
3305
  };
2464
3306
 
2465
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2466
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
3307
+ template <int mmq_x, int mmq_y, bool need_check>
3308
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
2467
3309
  static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
2468
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
2469
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2470
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3310
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
3311
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3312
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
2471
3313
  };
2472
3314
 
2473
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2474
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
3315
+ template <int mmq_x, int mmq_y, bool need_check>
3316
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
2475
3317
  static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
2476
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
2477
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2478
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3318
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
3319
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3320
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
2479
3321
  };
2480
3322
 
2481
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2482
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
3323
+ template <int mmq_x, int mmq_y, bool need_check>
3324
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
2483
3325
  static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
2484
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
2485
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2486
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3326
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
3327
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3328
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2487
3329
  };
2488
3330
 
2489
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2490
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
3331
+ template <int mmq_x, int mmq_y, bool need_check>
3332
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
2491
3333
  static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
2492
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
2493
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2494
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3334
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
3335
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3336
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2495
3337
  };
2496
3338
 
2497
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2498
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
3339
+ template <int mmq_x, int mmq_y, bool need_check>
3340
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
2499
3341
  static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
2500
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
2501
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2502
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3342
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
3343
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3344
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2503
3345
  };
2504
3346
 
2505
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2506
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
3347
+ template <int mmq_x, int mmq_y, bool need_check>
3348
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
2507
3349
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2508
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2509
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2510
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3350
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
3351
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3352
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2511
3353
  };
2512
3354
 
2513
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2514
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
3355
+ template <int mmq_x, int mmq_y, bool need_check>
3356
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
2515
3357
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2516
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2517
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2518
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3358
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
3359
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3360
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2519
3361
  };
2520
3362
 
2521
- template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
3363
+ template <ggml_type type, int mmq_x, bool need_check, bool fixup>
2522
3364
  static __device__ __forceinline__ void mul_mat_q_process_tile(
2523
3365
  const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
2524
3366
  const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525
3367
  const int stride_row_x, const int ncols_y, const int stride_col_dst,
2526
3368
  const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
2527
3369
 
3370
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3371
+ constexpr int nwarps = mmq_get_nwarps_device();
2528
3372
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2529
3373
  constexpr int mmq_y = get_mmq_y_device();
2530
- constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
3374
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
2531
3375
 
2532
3376
  extern __shared__ int data_mul_mat_q[];
2533
3377
  int * tile_y = data_mul_mat_q + mmq_x;
2534
- int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
3378
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3379
+
3380
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
3381
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3382
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3383
+ #else
3384
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3385
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3386
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2535
3387
 
2536
- #ifdef NEW_MMA_AVAILABLE
2537
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
2538
- constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
3388
+ #if defined(BLACKWELL_MMA_AVAILABLE)
3389
+ // FP4 tile stores 8 blocks
3390
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
2539
3391
  #else
2540
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
2541
- constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2542
- #endif // NEW_MMA_AVAILABLE
3392
+ constexpr int ne_block = 4 * QK8_1;
3393
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
3394
+
3395
+ constexpr int ITER_K = get_iter_k(type);
3396
+ constexpr int blocks_per_iter = ITER_K / qk;
2543
3397
 
2544
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3398
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
2545
3399
 
2546
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
3400
+ constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
2547
3401
 
2548
3402
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
2549
3403
  load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
2550
-
2551
3404
  {
2552
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
3405
+ const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
2553
3406
  #pragma unroll
2554
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2555
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3407
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3408
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2556
3409
 
2557
3410
  tile_y[l] = by0[l];
2558
3411
  }
@@ -2565,10 +3418,10 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2565
3418
  __syncthreads();
2566
3419
 
2567
3420
  {
2568
- const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
3421
+ const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
2569
3422
  #pragma unroll
2570
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2571
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3423
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
3424
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2572
3425
 
2573
3426
  tile_y[l] = by0[l];
2574
3427
  }
@@ -2576,7 +3429,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2576
3429
 
2577
3430
  __syncthreads();
2578
3431
 
2579
- vec_dot(tile_x, tile_y, sum, WARP_SIZE);
3432
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
2580
3433
 
2581
3434
  __syncthreads();
2582
3435
  }
@@ -2591,24 +3444,25 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2591
3444
 
2592
3445
  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
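A minimal host-side sketch of the stream-k split referenced in the comment above, assuming a contiguous index space of k-blocks divided evenly across the launched thread blocks (all sizes below are illustrative, not taken from this file):

    // Each thread block bidx takes the contiguous slice [kbc, kbc_stop) of the
    // flattened sample x channel x tile x k-block index space, mirroring the
    // kbc computation inside mul_mat_q. A tile whose k-range is cut in two is
    // completed later by the stream-k fixup kernel.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t blocks_per_ne00 = 48; // k-blocks per output tile (illustrative)
        const int64_t ntiles          = 7;  // ntx*nty*nchannels_y*nsamples_y (illustrative)
        const int64_t total           = ntiles*blocks_per_ne00;
        const int64_t nblocks         = 5;  // stand-in for gridDim.x

        for (int64_t bidx = 0; bidx < nblocks; ++bidx) {
            const int64_t kbc      =  bidx     *total / nblocks;
            const int64_t kbc_stop = (bidx + 1)*total / nblocks;
            const bool needs_fixup = kbc_stop % blocks_per_ne00 != 0; // slice ends mid-tile
            printf("block %2lld: k-blocks [%3lld, %3lld) fixup=%d\n",
                   (long long) bidx, (long long) kbc, (long long) kbc_stop, (int) needs_fixup);
        }
        return 0;
    }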
2593
3446
 
2594
- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2595
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3447
+ template <ggml_type type, int mmq_x, bool need_check>
3448
+ #if defined(GGML_USE_HIP)
2596
3449
  #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2597
- __launch_bounds__(WARP_SIZE*nwarps, 2)
3450
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
2598
3451
  #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2599
3452
  #else
2600
3453
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2601
- __launch_bounds__(WARP_SIZE*nwarps, 1)
3454
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
2602
3455
  #else
2603
- __launch_bounds__(WARP_SIZE*nwarps, 2)
3456
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
2604
3457
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2605
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3458
+ #endif // defined(GGML_USE_HIP)
2606
3459
  static __global__ void mul_mat_q(
2607
3460
  const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
2608
3461
  const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609
3462
  const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
2610
3463
  const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
2611
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
3464
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3465
+ const int ncols_max) {
2612
3466
 
2613
3467
  // Skip unused template specializations for faster compilation:
2614
3468
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2616,10 +3470,13 @@ static __global__ void mul_mat_q(
2616
3470
  return;
2617
3471
  }
2618
3472
 
3473
+ constexpr int nwarps = mmq_get_nwarps_device();
3474
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3475
+
2619
3476
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2620
3477
  constexpr int mmq_y = get_mmq_y_device();
2621
3478
 
2622
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
3479
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
2623
3480
  const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
2624
3481
 
2625
3482
  // Initialize the ids for writing back data with just the index.
@@ -2627,10 +3484,10 @@ static __global__ void mul_mat_q(
2627
3484
  // For MoE the correct indices are loaded from ids_dst.
2628
3485
  extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
2629
3486
  #pragma unroll
2630
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2631
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3487
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3488
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2632
3489
 
2633
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3490
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2634
3491
  break;
2635
3492
  }
2636
3493
 
@@ -2638,8 +3495,8 @@ static __global__ void mul_mat_q(
2638
3495
  }
2639
3496
  __syncthreads();
2640
3497
 
2641
- // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2642
- #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3498
+ // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
3499
+ #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2643
3500
  {
2644
3501
  const int wt = blockIdx.z / nchannels_y;
2645
3502
  const int zt = blockIdx.z - wt*nchannels_y;
@@ -2667,10 +3524,10 @@ static __global__ void mul_mat_q(
2667
3524
 
2668
3525
  // __syncthreads(); // There is no previous tile that could cause a race condition.
2669
3526
  #pragma unroll
2670
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2671
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3527
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3528
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2672
3529
 
2673
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3530
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2674
3531
  break;
2675
3532
  }
2676
3533
 
@@ -2688,15 +3545,17 @@ static __global__ void mul_mat_q(
2688
3545
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2689
3546
 
2690
3547
  constexpr bool fixup = false;
2691
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3548
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2692
3549
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2693
3550
  tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
2694
3551
  return;
2695
3552
  }
2696
- #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3553
+ #endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3554
+
3555
+ constexpr int ITER_K = get_iter_k(type);
2697
3556
 
2698
3557
  const int64_t blocks_per_ne00 = ncols_x / qk;
2699
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3558
+ constexpr int blocks_per_iter = ITER_K / qk;
2700
3559
 
2701
3560
  // kbc == k block continuous, current index in continuous ijk space.
2702
3561
  int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -2745,10 +3604,10 @@ static __global__ void mul_mat_q(
2745
3604
 
2746
3605
  __syncthreads();
2747
3606
  #pragma unroll
2748
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2749
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3607
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3608
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2750
3609
 
2751
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3610
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2752
3611
  break;
2753
3612
  }
2754
3613
 
@@ -2757,7 +3616,7 @@ static __global__ void mul_mat_q(
2757
3616
  __syncthreads();
2758
3617
  }
2759
3618
 
2760
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3619
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
2761
3620
  offset_dst += it*mmq_y;
2762
3621
 
2763
3622
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -2766,7 +3625,7 @@ static __global__ void mul_mat_q(
2766
3625
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2767
3626
 
2768
3627
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
2769
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3628
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2770
3629
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2771
3630
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2772
3631
 
@@ -2812,10 +3671,10 @@ static __global__ void mul_mat_q(
2812
3671
  // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
2813
3672
  __syncthreads();
2814
3673
  #pragma unroll
2815
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2816
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3674
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3675
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2817
3676
 
2818
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3677
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2819
3678
  break;
2820
3679
  }
2821
3680
 
@@ -2824,7 +3683,7 @@ static __global__ void mul_mat_q(
2824
3683
  __syncthreads();
2825
3684
  }
2826
3685
 
2827
- offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3686
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
2828
3687
  offset_dst += it*mmq_y;
2829
3688
 
2830
3689
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -2833,25 +3692,31 @@ static __global__ void mul_mat_q(
2833
3692
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2834
3693
 
2835
3694
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2836
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3695
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2837
3696
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2838
3697
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2839
3698
  }
2840
3699
 
2841
3700
 
2842
- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
3701
+ template <ggml_type type, int mmq_x, bool need_check>
2843
3702
  static __global__ void mul_mat_q_stream_k_fixup(
2844
3703
  const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2845
3704
  const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
2846
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
3705
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
3706
+ const int ncols_max) {
2847
3707
  constexpr int mmq_y = get_mmq_y_device();
2848
3708
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2849
- constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3709
+ constexpr int ITER_K = get_iter_k(type);
3710
+
3711
+ constexpr int blocks_per_iter = ITER_K / qk;
2850
3712
  const int64_t blocks_per_ne00 = ncols_x / qk;
2851
3713
 
2852
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
3714
+ constexpr int nwarps = mmq_get_nwarps_device();
3715
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3716
+
3717
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
2853
3718
 
2854
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
3719
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
2855
3720
  const int nty = (nrows_x + mmq_y - 1) / mmq_y;
2856
3721
 
2857
3722
  const int bidx0 = blockIdx.x;
@@ -2893,10 +3758,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
2893
3758
  const int j = j0 + threadIdx.y;
2894
3759
 
2895
3760
  #pragma unroll
2896
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3761
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2897
3762
  const int i = i0 + threadIdx.x;
2898
3763
 
2899
- sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3764
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
2900
3765
  }
2901
3766
  }
2902
3767
 
@@ -2937,14 +3802,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
2937
3802
  }
2938
3803
 
2939
3804
  #pragma unroll
2940
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3805
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2941
3806
  const int i = i0 + threadIdx.x;
2942
3807
 
2943
3808
  if (need_check && i > i_max) {
2944
3809
  continue;
2945
3810
  }
2946
3811
 
2947
- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
3812
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2948
3813
  }
2949
3814
  }
2950
3815
  return;
@@ -2955,8 +3820,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
2955
3820
  const int col_high = expert_bounds[zt + 1];
2956
3821
  const int col_diff = col_high - col_low;
2957
3822
 
2958
- for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
2959
- ids_dst_shared[j] = ids_dst[col_low + j];
3823
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
3824
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
2960
3825
  }
2961
3826
  __syncthreads();
2962
3827
 
@@ -2975,14 +3840,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
2975
3840
  }
2976
3841
 
2977
3842
  #pragma unroll
2978
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3843
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2979
3844
  const int i = i0 + threadIdx.x;
2980
3845
 
2981
3846
  if (need_check && i > i_max) {
2982
3847
  continue;
2983
3848
  }
2984
3849
 
2985
- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
3850
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2986
3851
  }
2987
3852
  }
2988
3853
  }
@@ -2992,17 +3857,17 @@ struct mmq_args {
2992
3857
  int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
2993
3858
  int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
2994
3859
  int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
2995
- bool use_stream_k;
3860
+ bool use_stream_k; int64_t ncols_max;
2996
3861
  };
2997
3862
 
2998
3863
  template<ggml_type type>
2999
- static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
3864
+ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
3000
3865
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3001
3866
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3002
3867
  const size_t nbs_ids = mmq_x*sizeof(int);
3003
- const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3004
- const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3005
- return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
3868
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3869
+ const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
3870
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3006
3871
  }
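The byte count returned by mmq_get_nbytes_shared is later used both to raise the kernel's dynamic shared-memory limit and as the launch's shared-memory argument (the removed lines further down show the raw cudaFuncSetAttribute call that the new CUDA_SET_SHARED_MEMORY_LIMIT macro presumably wraps). A small sketch of that host-side pattern, with a placeholder kernel standing in for mul_mat_q:

    #include <cuda_runtime.h>

    __global__ void placeholder_kernel() {} // stands in for mul_mat_q<...>

    static void launch_with_large_smem(size_t nbytes_shared, cudaStream_t stream) {
        // Opt in to more than the default 48 KiB of dynamic shared memory per block;
        // without this the launch fails once the tile no longer fits the default limit.
        cudaError_t err = cudaFuncSetAttribute(placeholder_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes_shared);
        if (err != cudaSuccess) {
            return; // real code checks the error (CUDA_CHECK in this file)
        }
        placeholder_kernel<<<1, 256, nbytes_shared, stream>>>();
    }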
3007
3872
 
3008
3873
  template <ggml_type type, int mmq_x>
@@ -3010,23 +3875,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3010
3875
  const int id = ggml_cuda_get_device();
3011
3876
  const int cc = ggml_cuda_info().devices[id].cc;
3012
3877
  const int nsm = ggml_cuda_info().devices[id].nsm;
3878
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
3879
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3013
3880
  const int mmq_y = get_mmq_y_host(cc);
3014
3881
 
3015
- const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
3882
+ const dim3 block_dims(warp_size, nwarps, 1);
3016
3883
 
3017
- const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
3884
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
3018
3885
 
3019
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3020
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
3021
- if (!shared_memory_limit_raised[id]) {
3022
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3023
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3024
- shared_memory_limit_raised[id] = true;
3025
- }
3026
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3886
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
3887
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
3027
3888
 
3028
3889
  const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3029
- const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
3890
+ const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
3030
3891
  const int ntzw = args.nchannels_y * args.nsamples_y;
3031
3892
  const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
3032
3893
 
@@ -3038,18 +3899,20 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3038
3899
  if (!args.use_stream_k) {
3039
3900
  if (args.nrows_x % mmq_y == 0) {
3040
3901
  constexpr bool need_check = false;
3041
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3902
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3042
3903
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3043
3904
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3044
3905
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3045
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3906
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3907
+ args.ncols_max);
3046
3908
  } else {
3047
3909
  constexpr bool need_check = true;
3048
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3910
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3049
3911
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3050
3912
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3051
3913
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3052
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3914
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3915
+ args.ncols_max);
3053
3916
  }
3054
3917
  return;
3055
3918
  }
@@ -3065,44 +3928,48 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3065
3928
 
3066
3929
  if (args.nrows_x % mmq_y == 0) {
3067
3930
  constexpr bool need_check = false;
3068
-
3069
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3931
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3070
3932
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3071
3933
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3072
3934
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3073
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3935
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3936
+ args.ncols_max);
3074
3937
 
3075
3938
  if (!fixup_needed) {
3076
3939
  return;
3077
3940
  }
3078
3941
 
3079
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3942
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3080
3943
  (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3081
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
3944
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3945
+ args.ncols_max);
3082
3946
  } else {
3083
3947
  constexpr bool need_check = true;
3084
-
3085
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3948
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3086
3949
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3087
3950
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3088
3951
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3089
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3952
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3953
+ args.ncols_max);
3090
3954
 
3091
3955
  if (!fixup_needed) {
3092
3956
  return;
3093
3957
  }
3094
3958
 
3095
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3959
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3096
3960
  (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3097
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
3961
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3962
+ args.ncols_max);
3098
3963
  }
3099
3964
  }
3100
3965
 
3101
3966
  template <ggml_type type>
3102
3967
  void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3103
- const int id = ggml_cuda_get_device();
3104
- const int cc = ggml_cuda_info().devices[id].cc;
3105
- const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3968
+ const int id = ggml_cuda_get_device();
3969
+ const int cc = ggml_cuda_info().devices[id].cc;
3970
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3971
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
3972
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3106
3973
 
3107
3974
  const int mmq_x_max = get_mmq_x_max_host(cc);
3108
3975
  const int mmq_y = get_mmq_y_host(cc);
@@ -3113,11 +3980,11 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
3113
3980
  for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
3114
3981
  const int granularity = mmq_get_granularity_host(mmq_x, cc);
3115
3982
 
3116
- if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
3983
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
3117
3984
  continue;
3118
3985
  }
3119
3986
 
3120
- const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
3987
+ const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
3121
3988
 
3122
3989
  if (ntiles_x < ntiles_x_best) {
3123
3990
  mmq_x_best = mmq_x;
@@ -3189,6 +4056,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
3189
4056
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
3190
4057
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
3191
4058
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
4059
+ extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
3192
4060
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
3193
4061
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
3194
4062
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -3214,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
3214
4082
  const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
3215
4083
  const int64_t src1_padded_row_size, cudaStream_t stream);
3216
4084
 
3217
- bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
4085
+ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);