whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
58
58
  return MMQ_Q8_1_DS_LAYOUT_DS4;
59
59
  case GGML_TYPE_Q8_0:
60
60
  return MMQ_Q8_1_DS_LAYOUT_D4;
61
+ case GGML_TYPE_MXFP4:
62
+ return MMQ_Q8_1_DS_LAYOUT_D4;
61
63
  case GGML_TYPE_Q2_K:
62
64
  return MMQ_Q8_1_DS_LAYOUT_D2S6;
63
65
  case GGML_TYPE_Q3_K:
@@ -90,7 +92,7 @@ struct tile_x_sizes {
90
92
  };
91
93
 
92
94
  static int get_mmq_x_max_host(const int cc) {
93
- return new_mma_available(cc) ? 128 :
95
+ return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
94
96
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
95
97
  #ifdef GGML_CUDA_FORCE_MMQ
96
98
  128 : 64;
@@ -100,13 +102,13 @@ static int get_mmq_x_max_host(const int cc) {
100
102
  }
101
103
 
102
104
  static constexpr __device__ int get_mmq_x_max_device() {
103
- #ifdef NEW_MMA_AVAILABLE
105
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
104
106
  return 128;
105
- #else // NEW_MMA_AVAILABLE
107
+ #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
106
108
 
107
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
108
- return 128;
109
- #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
109
+ #if defined(GGML_USE_HIP)
110
+ return 64;
111
+ #else // defined(GGML_USE_HIP)
110
112
 
111
113
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
112
114
  #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,12 +117,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
115
117
  return MMQ_DP4A_MAX_BATCH_SIZE;
116
118
  #endif // GGML_CUDA_FORCE_MMQ
117
119
  #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
118
-
119
120
  return 64;
120
121
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
121
122
 
122
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
123
- #endif // NEW_MMA_AVAILABLE
123
+ #endif // defined(GGML_USE_HIP)
124
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
124
125
  }
125
126
 
126
127
  static int get_mmq_y_host(const int cc) {
@@ -129,7 +130,7 @@ static int get_mmq_y_host(const int cc) {
129
130
  }
130
131
 
131
132
  static constexpr __device__ int get_mmq_y_device() {
132
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
133
+ #if defined(GGML_USE_HIP)
133
134
  #if defined(RDNA1)
134
135
  return 64;
135
136
  #else
@@ -141,19 +142,28 @@ static constexpr __device__ int get_mmq_y_device() {
141
142
  #else
142
143
  return 64;
143
144
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
144
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
145
+ #endif // defined(GGML_USE_HIP)
145
146
  }
146
147
 
147
- #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
148
- #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
149
- #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
150
- #define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
151
- #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
152
- #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
153
- #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
154
- #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
155
- #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
156
- #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
148
+ // Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
149
+ // The K dimension of the tiles has either,
150
+ // 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
151
+ // 32 bit elements for the quantized data (does not include scales).
152
+ // In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
153
+ // The final tile size in K direction is padded to avoid shared memory bank conflicts,
154
+ // in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
155
+ #define MMQ_TILE_NE_K 32
156
+
157
+ #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
158
+ #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
159
+ #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
160
+ #define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
161
+ #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
162
+ #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
163
+ #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
164
+ #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
165
+ #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
166
+ #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
157
167
 
158
168
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
159
169
  switch (type) {
@@ -162,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
162
172
  case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
163
173
  case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
164
174
  case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
175
+ case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
165
176
  case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
166
177
  case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
167
178
  case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -179,11 +190,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
179
190
  }
180
191
  }
181
192
 
182
- #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
183
- #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
184
- #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
185
- #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
186
- #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
193
+ #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
194
+ #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
195
+ #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
196
+ #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
197
+ #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
187
198
 
188
199
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
189
200
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -198,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
198
209
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
199
210
  case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
200
211
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
212
+ case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
201
213
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
202
214
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
203
215
  case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -215,42 +227,76 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
215
227
  }
216
228
  }
217
229
 
218
- #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
230
+ // block_q8_1_mmq holds 128 8-bit ints == 32 32-bit ints of quantized data, plus 4 32-bit scales.
231
+ #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
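Plugging in the ggml constants (QI8_1 == 8 and QI8_0 == 8, both QK/(4*QR)), MMQ_TILE_Y_K comes out to 36 ints per tile row: the 32 ints of quantized data plus the 4 ints of scales named in the comment. The same arithmetic also shows why the MMA strides above end in "+ 4". A compile-time sketch, with the constants restated:

    // Compile-time check of the Y-tile stride and of the K % 8 == 4 padding rule.
    constexpr int MMQ_TILE_NE_K = 32;
    constexpr int QI8_0 = 8, QI8_1 = 8;

    constexpr int TILE_Y_K      = MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1;         // 32 + 4 == 36
    constexpr int TILE_X_K_Q8_0 = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4; // 64 + 8 + 4 == 76

    static_assert(TILE_Y_K == 36, "36 ints per q8_1 tile row (quants + scales)");
    static_assert(TILE_X_K_Q8_0 % 8 == 4, "padding that avoids shared memory bank conflicts for mma");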
219
232
 
220
233
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
221
- return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
234
+ if (amd_mfma_available(cc)) {
235
+ return mmq_x >= 128 ? 32 : 16;
236
+ } else if (turing_mma_available(cc) && mmq_x >= 48) {
237
+ return 16;
238
+ } else {
239
+ return 8;
240
+ }
222
241
  }
223
242
 
224
- #ifdef NEW_MMA_AVAILABLE
243
+ #if defined(AMD_MFMA_AVAILABLE)
244
+ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
245
+ return mmq_x >= 128 ? 32 : 16;
246
+ }
247
+ #elif defined(TURING_MMA_AVAILABLE)
225
248
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
226
249
  return mmq_x >= 48 ? 16 : 8;
227
250
  }
228
251
  #else
229
- static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
252
+ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
230
253
  return 8;
231
254
  }
232
- #endif // NEW_MMA_AVAILABLE
255
+ #endif // AMD_MFMA_AVAILABLE
256
+
257
+ #if defined(GGML_USE_HIP)
258
+ static int mmq_get_nwarps_host(const int cc, const int warp_size) {
259
+ return amd_mfma_available(cc) ? 8 : 256/warp_size;
260
+ }
261
+ #else
262
+ static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
263
+ return 256/warp_size;
264
+ }
265
+ #endif // defined(GGML_USE_HIP)
266
+
267
+ static constexpr __device__ int mmq_get_nwarps_device() {
268
+ #if defined(AMD_MFMA_AVAILABLE)
269
+ return 8;
270
+ #else
271
+ return 256/ggml_cuda_get_physical_warp_size();
272
+ #endif // AMD_MFMA_AVAILABLE
273
+ }
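Taken together, the helpers above size a block at 256 threads on NVIDIA and on HIP without MFMA (8 warps of 32 or 4 wavefronts of 64), and pin it to 8 wavefronts when MFMA is used (512 threads on those 64-lane GPUs). A host-side sketch of the same selection, with amd_mfma_available()/turing_mma_available() modelled as plain booleans for hypothetical devices:

    // Sketch of the nwarps/granularity selection for a few hypothetical devices.
    #include <cstdio>

    static int nwarps_host(bool mfma, int warp_size) {
        return mfma ? 8 : 256/warp_size;
    }
    static int granularity_host(bool mfma, bool turing_mma, int mmq_x) {
        if (mfma)                      return mmq_x >= 128 ? 32 : 16;
        if (turing_mma && mmq_x >= 48) return 16;
        return 8;
    }

    int main() {
        // NVIDIA Turing or newer, warp size 32, mmq_x == 64: 8 warps, granularity 16.
        printf("%d %d\n", nwarps_host(false, 32), granularity_host(false, true, 64));
        // CDNA with MFMA, wavefront 64, mmq_x == 128: 8 wavefronts, granularity 32.
        printf("%d %d\n", nwarps_host(true, 64), granularity_host(true, false, 128));
        // RDNA without MFMA, wavefront 32, mmq_x == 32: 8 warps, granularity 8.
        printf("%d %d\n", nwarps_host(false, 32), granularity_host(false, false, 32));
        return 0;
    }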
233
274
 
234
275
  // ------------------------------------------------------------
235
276
 
236
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
277
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
237
278
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
279
+ constexpr int nwarps = mmq_get_nwarps_device();
280
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
238
281
 
239
- #ifdef NEW_MMA_AVAILABLE
282
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
240
283
  int * x_qs = (int *) x_tile;
241
- float * x_df = (float *) (x_qs + 2*WARP_SIZE);
284
+ float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
242
285
  #else
243
286
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
244
287
  int * x_qs = (int *) x_tile;
245
288
  float * x_df = (float *) (x_qs + txs.qs);
246
- #endif // NEW_MMA_AVAILABLE
289
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
247
290
 
248
- const int kbx = threadIdx.x / QI4_0;
249
- const int kqsx = threadIdx.x % QI4_0;
291
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
292
+ constexpr int nrows = warp_size / threads_per_row;
293
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
294
+ const int kbx = txi / QI4_0;
295
+ const int kqsx = txi % QI4_0;
250
296
 
251
297
  #pragma unroll
252
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
253
- int i = i0 + threadIdx.y;
298
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
299
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
254
300
 
255
301
  if (need_check) {
256
302
  i = min(i, i_max);
@@ -259,20 +305,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
259
305
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
260
306
  const int qs0 = get_int_b2(bxi->qs, kqsx);
261
307
 
262
- #ifdef NEW_MMA_AVAILABLE
308
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
263
309
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
264
310
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
265
311
  #else
266
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
267
- #endif // NEW_MMA_AVAILABLE
312
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
313
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
268
314
  }
269
315
 
270
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
316
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
317
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
271
318
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
272
319
 
273
320
  #pragma unroll
274
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
275
- int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
321
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
322
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
276
323
 
277
324
  if (need_check) {
278
325
  i = min(i, i_max);
@@ -280,17 +327,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
280
327
 
281
328
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
282
329
 
283
- #ifdef NEW_MMA_AVAILABLE
284
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
330
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
331
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
285
332
  #else
286
- x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
287
- #endif // NEW_MMA_AVAILABLE
333
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
334
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
288
335
  }
289
336
  }
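The new index math at the top of load_tiles_q4_0 only changes behaviour when the warp is wider than a tile row. Assuming MMQ_ITER_K == 256 (it is defined earlier in this header) and QR4_0 == 2, threads_per_row is 32: a 32-wide NVIDIA warp then covers one row per pass (nrows == 1, txi == threadIdx.x), while a 64-wide wavefront covers two rows at once and txi wraps at 32. A small sketch of that mapping:

    // Sketch of the per-thread row/column mapping used by load_tiles_q4_0.
    // MMQ_ITER_K == 256, QR4_0 == 2, QI4_0 == 4 are assumptions restated from ggml's headers.
    #include <cstdio>

    int main() {
        constexpr int MMQ_ITER_K = 256, QR4_0 = 2, QI4_0 = 4;
        constexpr int threads_per_row = MMQ_ITER_K / (4*QR4_0); // 32

        const int warp_sizes[2] = {32, 64};
        const int tids[4]       = {0, 31, 33, 63};
        for (int w = 0; w < 2; ++w) {
            const int warp_size = warp_sizes[w];
            const int nrows     = warp_size / threads_per_row;   // 1 on NVIDIA, 2 on a 64-wide wavefront
            printf("warp_size=%d nrows=%d\n", warp_size, nrows);
            for (int t = 0; t < 4; ++t) {
                const int tid = tids[t];
                if (tid >= warp_size) continue;
                const int txi  = warp_size > threads_per_row ? tid % threads_per_row : tid;
                const int row  = tid / threads_per_row; // extra row covered by the wide wavefront
                const int kbx  = txi / QI4_0;           // which q4_0 block within the row
                const int kqsx = txi % QI4_0;           // which int within that block
                printf("  tid=%2d row=%d kbx=%d kqsx=%d\n", tid, row, kbx, kqsx);
            }
        }
        return 0;
    }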
290
337
 
291
- template <int mmq_x, int mmq_y, int nwarps>
338
+ template <int mmq_x, int mmq_y>
292
339
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
293
340
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
341
+ constexpr int nwarps = mmq_get_nwarps_device();
342
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
294
343
 
295
344
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
296
345
  const int * x_qs = (const int *) x;
@@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
299
348
  const half2 * y_ds = (const half2 *) y;
300
349
 
301
350
  // #pragma unroll
302
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
351
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
303
352
  const int k0 = k00 + k01;
304
353
 
305
354
  #pragma unroll
@@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
307
356
  const int j = j0 + threadIdx.y;
308
357
 
309
358
  #pragma unroll
310
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
359
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
311
360
  const int i = i0 + threadIdx.x;
312
361
 
313
362
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
320
369
  u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
321
370
  }
322
371
 
323
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
324
- (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
325
- x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
372
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
373
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
374
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
326
375
  }
327
376
  }
328
377
  }
329
378
  }
330
379
 
331
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
380
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
332
381
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
382
+ constexpr int nwarps = mmq_get_nwarps_device();
383
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
333
384
 
334
- #ifdef NEW_MMA_AVAILABLE
385
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
335
386
  int * x_qs = (int *) x_tile;
336
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
387
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
337
388
  #else
338
389
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
339
390
  int * x_qs = (int *) x_tile;
340
391
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
341
- #endif // NEW_MMA_AVAILABLE
392
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
342
393
 
343
- const int kbx = threadIdx.x / QI4_1;
344
- const int kqsx = threadIdx.x % QI4_1;
394
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
395
+ constexpr int nrows = warp_size / threads_per_row;
396
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
397
+ const int kbx = txi / QI4_1;
398
+ const int kqsx = txi % QI4_1;
345
399
 
346
400
  #pragma unroll
347
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
348
- int i = i0 + threadIdx.y;
401
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
402
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
349
403
 
350
404
  if (need_check) {
351
405
  i = min(i, i_max);
@@ -354,20 +408,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
354
408
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
355
409
  const int qs0 = get_int_b4(bxi->qs, kqsx);
356
410
 
357
- #ifdef NEW_MMA_AVAILABLE
411
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
358
412
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
359
413
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
360
414
  #else
361
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
362
- #endif // NEW_MMA_AVAILABLE
415
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
416
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
363
417
  }
364
418
 
365
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
419
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
420
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
366
421
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
367
422
 
368
423
  #pragma unroll
369
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
370
- int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
424
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
425
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
371
426
 
372
427
  if (need_check) {
373
428
  i = min(i, i_max);
@@ -375,17 +430,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
375
430
 
376
431
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
377
432
 
378
- #ifdef NEW_MMA_AVAILABLE
379
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
433
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
434
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
380
435
  #else
381
- x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
382
- #endif // NEW_MMA_AVAILABLE
436
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
437
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
383
438
  }
384
439
  }
385
440
 
386
- template <int mmq_x, int mmq_y, int nwarps>
441
+ template <int mmq_x, int mmq_y>
387
442
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
388
443
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
444
+ constexpr int nwarps = mmq_get_nwarps_device();
445
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
389
446
 
390
447
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
391
448
  const int * x_qs = (const int *) x;
@@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
394
451
  const half2 * y_ds = (const half2 *) y;
395
452
 
396
453
  // #pragma unroll
397
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
454
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
398
455
  const int k0 = k00 + k01;
399
456
 
400
457
  #pragma unroll
@@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
402
459
  const int j = j0 + threadIdx.y;
403
460
 
404
461
  #pragma unroll
405
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
462
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
406
463
  const int i = i0 + threadIdx.x;
407
464
 
408
465
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
415
472
  u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
416
473
  }
417
474
 
418
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
419
- (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
420
- x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
475
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
476
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
477
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
421
478
  }
422
479
  }
423
480
  }
424
481
  }
425
482
 
426
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
483
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
427
484
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
485
+ constexpr int nwarps = mmq_get_nwarps_device();
486
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
428
487
 
429
- #ifdef NEW_MMA_AVAILABLE
488
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
430
489
  int * x_qs = (int *) x_tile;
431
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
490
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
432
491
  #else
433
492
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
434
493
  int * x_qs = (int *) x_tile;
435
494
  float * x_df = (float *) (x_qs + txs.qs);
436
- #endif // NEW_MMA_AVAILABLE
495
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
437
496
 
438
- const int kbx = threadIdx.x / QI5_0;
439
- const int kqsx = threadIdx.x % QI5_0;
497
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498
+ constexpr int nrows = warp_size / threads_per_row;
499
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
500
+ const int kbx = txi / QI5_0;
501
+ const int kqsx = txi % QI5_0;
440
502
 
441
503
  #pragma unroll
442
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
443
- int i = i0 + threadIdx.y;
504
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
505
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
444
506
 
445
507
  if (need_check) {
446
508
  i = min(i, i_max);
@@ -449,7 +511,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
449
511
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
450
512
 
451
513
  const int ql = get_int_b2(bxi->qs, kqsx);
452
- const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
514
+ const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
453
515
 
454
516
  int qs0 = (ql >> 0) & 0x0F0F0F0F;
455
517
  qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -465,21 +527,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
465
527
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
466
528
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
467
529
 
468
- #ifdef NEW_MMA_AVAILABLE
530
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
469
531
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
470
532
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
471
533
  #else
472
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
473
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
474
- #endif // NEW_MMA_AVAILABLE
534
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
475
537
  }
476
538
 
477
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
539
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
540
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
478
541
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
479
542
 
480
543
  #pragma unroll
481
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
482
- int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
544
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
545
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
483
546
 
484
547
  if (need_check) {
485
548
  i = min(i, i_max);
@@ -487,32 +550,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
487
550
 
488
551
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
489
552
 
490
- #ifdef NEW_MMA_AVAILABLE
491
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
553
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
554
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
492
555
  #else
493
- x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
494
- #endif // NEW_MMA_AVAILABLE
556
+ x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
495
558
  }
496
559
  }
497
560
 
498
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
561
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
499
562
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
563
+ constexpr int nwarps = mmq_get_nwarps_device();
564
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
500
565
 
501
- #ifdef NEW_MMA_AVAILABLE
566
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
502
567
  int * x_qs = (int *) x_tile;
503
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
568
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
504
569
  #else
505
570
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
506
571
  int * x_qs = (int *) x_tile;
507
572
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
508
- #endif // NEW_MMA_AVAILABLE
573
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
509
574
 
510
- const int kbx = threadIdx.x / QI5_1;
511
- const int kqsx = threadIdx.x % QI5_1;
575
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576
+ constexpr int nrows = warp_size / threads_per_row;
577
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
578
+ const int kbx = txi / QI5_1;
579
+ const int kqsx = txi % QI5_1;
512
580
 
513
581
  #pragma unroll
514
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
515
- int i = i0 + threadIdx.y;
582
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
583
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
516
584
 
517
585
  if (need_check) {
518
586
  i = min(i, i_max);
@@ -521,7 +589,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
521
589
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
522
590
 
523
591
  const int ql = get_int_b4(bxi->qs, kqsx);
524
- const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
592
+ const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
525
593
 
526
594
  int qs0 = (ql >> 0) & 0x0F0F0F0F;
527
595
  qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -535,21 +603,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
535
603
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
536
604
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
537
605
 
538
- #ifdef NEW_MMA_AVAILABLE
606
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
539
607
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
540
608
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
541
609
  #else
542
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
543
- x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
544
- #endif // NEW_MMA_AVAILABLE
610
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
545
613
  }
546
614
 
547
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
615
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
616
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
548
617
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
549
618
 
550
619
  #pragma unroll
551
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
552
- int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
620
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
621
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
553
622
 
554
623
  if (need_check) {
555
624
  i = min(i, i_max);
@@ -557,32 +626,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
557
626
 
558
627
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
559
628
 
560
- #ifdef NEW_MMA_AVAILABLE
561
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
629
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
630
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
562
631
  #else
563
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
564
- #endif // NEW_MMA_AVAILABLE
632
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
565
634
  }
566
635
  }
567
636
 
568
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
637
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
569
638
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
639
+ constexpr int nwarps = mmq_get_nwarps_device();
640
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
570
641
 
571
- #ifdef NEW_MMA_AVAILABLE
642
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
572
643
  int * x_qs = (int *) x_tile;
573
- float * x_df = (float *) (x_tile + 2*WARP_SIZE);
644
+ float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
574
645
  #else
575
646
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
576
647
  int * x_qs = (int *) x_tile;
577
648
  float * x_df = (float *) (x_qs + txs.qs);
578
- #endif // NEW_MMA_AVAILABLE
649
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
579
650
 
580
- const int kbx = threadIdx.x / QI8_0;
581
- const int kqsx = threadIdx.x % QI8_0;
651
+ // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA warps have only 32 threads.
652
+ constexpr int threads_per_row = 32;
653
+ constexpr int nrows = warp_size / threads_per_row;
654
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
655
+ const int kbx = txi / QI8_0;
656
+ const int kqsx = txi % QI8_0;
582
657
 
583
658
  #pragma unroll
584
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
585
- int i = i0 + threadIdx.y;
659
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
660
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
586
661
 
587
662
  if (need_check) {
588
663
  i = min(i, i_max);
@@ -590,21 +665,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
590
665
 
591
666
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
592
667
 
593
- #ifdef NEW_MMA_AVAILABLE
594
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
595
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
668
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
669
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
596
671
  #else
597
- x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
598
- x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
599
- #endif // NEW_MMA_AVAILABLE
672
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
600
675
  }
601
676
 
602
- const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
677
+ constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
678
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
603
679
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
604
680
 
605
681
  #pragma unroll
606
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
607
- int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
682
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
683
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
608
684
 
609
685
  if (need_check) {
610
686
  i = min(i, i_max);
@@ -612,17 +688,84 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
612
688
 
613
689
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
614
690
 
615
- #ifdef NEW_MMA_AVAILABLE
616
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
691
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
692
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
617
693
  #else
618
- x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
619
- #endif // NEW_MMA_AVAILABLE
694
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
620
696
  }
621
697
  }
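q8_0 is the one case where the generic rule cannot be applied directly: with QR8_0 == 1, MMQ_ITER_K / (4*QR8_0) would call for 64 threads per row, more than a 32-wide NVIDIA warp has, so threads_per_row is pinned to 32 and each thread loads two ints per row instead (the bxi[0] and bxi[MMQ_TILE_NE_K/QI8_0] reads above). The arithmetic, assuming MMQ_ITER_K == 256:

    // Why load_tiles_q8_0 hard-codes threads_per_row = 32 (MMQ_ITER_K == 256 assumed).
    constexpr int MMQ_ITER_K = 256, QR8_0 = 1, MMQ_TILE_NE_K = 32;

    constexpr int wanted_threads_per_row = MMQ_ITER_K / (4*QR8_0);        // 64: wider than an NVIDIA warp
    constexpr int threads_per_row        = 32;                            // capped, as in the kernel
    constexpr int ints_per_row           = 2*MMQ_TILE_NE_K;               // 64 ints of q8_0 data per tile row
    constexpr int ints_per_thread        = ints_per_row/threads_per_row;  // 2: hence the two get_int_b2() loads

    static_assert(wanted_threads_per_row == 2*threads_per_row, "each thread covers two K positions");
    static_assert(ints_per_thread == 2, "two quantized ints stored per thread and row");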
622
698
 
623
- template <int mmq_x, int mmq_y, int nwarps>
699
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
700
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
701
+ constexpr int nwarps = mmq_get_nwarps_device();
702
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
+
704
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
705
+ int * x_qs = (int *) x_tile;
706
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
+ #else
708
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
+ int * x_qs = (int *) x_tile;
710
+ float * x_df = (float *) (x_qs + txs.qs);
711
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
712
+
713
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
+ constexpr int nrows = warp_size / threads_per_row;
715
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
716
+ const int kbx = txi / QI_MXFP4;
717
+ const int kqsx = txi % QI_MXFP4;
718
+
719
+ #pragma unroll
720
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
721
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
722
+
723
+ if (need_check) {
724
+ i = min(i, i_max);
725
+ }
726
+
727
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
728
+
729
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
730
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
+ const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
+
733
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
734
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
+ #else
737
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
740
+ }
741
+
742
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
743
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
744
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
745
+
746
+ #pragma unroll
747
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
748
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
749
+
750
+ if (need_check) {
751
+ i = min(i, i_max);
752
+ }
753
+
754
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
+
756
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
757
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
+ #else
759
+ x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
761
+ }
762
+ }
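The MXFP4 scale bxi->e is an E8M0 value: a bare 8-bit power-of-two exponent with bias 127, which ggml_cuda_e8m0_to_fp32() turns into a float. Below is a minimal host-side sketch of that mapping (the real helper's edge-case handling may differ); the additional *0.5f in the loads above presumably compensates for kvalues_mxfp4 storing the 4-bit values at twice their nominal magnitude so the lookup table stays integral.

    // Host-side sketch of an E8M0 decode, per the OCP MX spec: value == 2^(e - 127).
    // ggml_cuda_e8m0_to_fp32() is assumed to implement the same mapping on the device.
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static float e8m0_to_fp32_sketch(uint8_t e) {
        return std::ldexp(1.0f, (int) e - 127); // 2^(e - 127)
    }

    int main() {
        printf("e=127 -> %g\n", e8m0_to_fp32_sketch(127)); // 1
        printf("e=130 -> %g\n", e8m0_to_fp32_sketch(130)); // 8
        printf("e=121 -> %g\n", e8m0_to_fp32_sketch(121)); // 0.015625
        return 0;
    }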
763
+
764
+ template <int mmq_x, int mmq_y>
624
765
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
625
766
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
767
+ constexpr int nwarps = mmq_get_nwarps_device();
768
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
626
769
 
627
770
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
628
771
  const int * x_qs = (const int *) x;
@@ -631,7 +774,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
631
774
  const float * y_df = (const float *) y;
632
775
 
633
776
  // #pragma unroll
634
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
777
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
635
778
  const int k0 = k00 + k01;
636
779
 
637
780
  #pragma unroll
@@ -639,21 +782,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
639
782
  const int j = j0 + threadIdx.y;
640
783
 
641
784
  #pragma unroll
642
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
785
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
643
786
  const int i = i0 + threadIdx.x;
644
787
 
645
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
646
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
647
- x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
788
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
789
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
790
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
648
791
  }
649
792
  }
650
793
  }
651
794
  }
652
795
 
653
- template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
796
+ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
654
797
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
655
798
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
799
+ #if defined(AMD_MFMA_AVAILABLE)
800
+ typedef tile<16, 8, int> tile_A;
801
+ typedef tile<16, 8, int> tile_B;
802
+ typedef tile<16, 16, int> tile_C;
803
+
804
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
805
+ constexpr int rows_per_warp = granularity;
806
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
807
+
808
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
809
+
810
+ const int * x_qs = (const int *) x;
811
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
812
+ const int * y_qs = (const int *) y + 4;
813
+ const float * y_df = (const float *) y;
814
+ const half2 * y_ds = (const half2 *) y;
656
815
 
816
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
817
+
818
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
819
+ const int k0 = k00 + k01;
820
+
821
+ tile_A A[ntx];
822
+ #pragma unroll
823
+ for (int n = 0; n < ntx; ++n) {
824
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
825
+ }
826
+
827
+ #pragma unroll
828
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
829
+ tile_B B;
830
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
831
+
832
+ float dB;
833
+ const int j = j0 + tile_C::get_j(0);
834
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
835
+ dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
836
+ } else {
837
+ dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
838
+ }
839
+
840
+ #pragma unroll
841
+ for (int n = 0; n < ntx; ++n) {
842
+ tile_C C;
843
+ mma(C, A[n], B);
844
+
845
+ #pragma unroll
846
+ for (int l = 0; l < tile_C::ne; ++l) {
847
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
848
+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
849
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
850
+ }
851
+ }
852
+ }
853
+ }
854
+ #else
657
855
  typedef tile<16, 8, int> tile_A;
658
856
  typedef tile< 8, 8, int> tile_B;
659
857
  typedef tile<16, 8, int> tile_C;
@@ -662,23 +860,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
662
860
  constexpr int rows_per_warp = 2 * granularity;
663
861
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
664
862
 
665
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
863
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
666
864
 
667
865
  const int * x_qs = (const int *) x;
668
- const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
866
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
669
867
  const int * y_qs = (const int *) y + 4;
670
868
  const float * y_df = (const float *) y;
671
869
  const half2 * y_ds = (const half2 *) y;
672
870
 
673
- tile_A A[ntx][WARP_SIZE/QI8_0];
674
- float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
871
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
872
+ float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
675
873
 
676
874
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
677
875
 
678
876
  #pragma unroll
679
877
  for (int n = 0; n < ntx; ++n) {
680
878
  #pragma unroll
681
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
879
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
682
880
  const int k0 = k00 + k01;
683
881
 
684
882
  load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
@@ -689,7 +887,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
689
887
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
690
888
 
691
889
  #pragma unroll
692
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
890
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
693
891
  const int k0 = k00 + k01;
694
892
 
695
893
  dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
@@ -700,7 +898,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
700
898
  #pragma unroll
701
899
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
702
900
  #pragma unroll
703
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
901
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
704
902
  tile_B B;
705
903
  float dB[tile_C::ne/2];
706
904
 
@@ -729,11 +927,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
729
927
  }
730
928
  }
731
929
  }
930
+ #endif // defined(AMD_MFMA_AVAILABLE)
732
931
  }
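The two branches above differ only in fragment shapes and in how many x mini-tiles one warp owns: the MFMA path uses 16x8 A and B fragments with a 16x16 accumulator and rows_per_warp == granularity, while the mma.sync path keeps 16x8 / 8x8 / 16x8 fragments and uses rows_per_warp == 2*granularity. In both cases ntx = rows_per_warp / tile_C::I, which works out as follows (shapes and granularities restated by hand from the code above):

    // ntx (x mini-tiles per warp) for the two MMA paths.
    constexpr int tile_C_I_mfma   = 16; // 16x16 accumulator tile on CDNA/MFMA
    constexpr int tile_C_I_turing = 16; // 16x8 accumulator tile on Turing+ mma.sync

    // MFMA: rows_per_warp == granularity (16, or 32 for mmq_x >= 128).
    static_assert(16 / tile_C_I_mfma == 1, "mmq_x <  128: one mini-tile per warp");
    static_assert(32 / tile_C_I_mfma == 2, "mmq_x >= 128: two mini-tiles per warp");

    // Turing: rows_per_warp == 2*granularity (granularity 8, or 16 for mmq_x >= 48).
    static_assert(2*8  / tile_C_I_turing == 1, "mmq_x <  48: one mini-tile per warp");
    static_assert(2*16 / tile_C_I_turing == 2, "mmq_x >= 48: two mini-tiles per warp");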
733
932
 
734
- template <int mmq_x, int mmq_y, int nwarps>
933
+ template <int mmq_x, int mmq_y>
735
934
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
736
935
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
936
+ constexpr int nwarps = mmq_get_nwarps_device();
937
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
737
938
 
738
939
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
739
940
  const int * x_qs = (const int *) x;
@@ -742,7 +943,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
742
943
  const half2 * y_ds = (const half2 *) y;
743
944
 
744
945
  // #pragma unroll
745
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
946
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
746
947
  const int k0 = k00 + k01;
747
948
 
748
949
  #pragma unroll
@@ -750,45 +951,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
750
951
  const int j = j0 + threadIdx.y;
751
952
 
752
953
  #pragma unroll
753
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
954
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
754
955
  const int i = i0 + threadIdx.x;
755
956
 
756
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
757
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
758
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
957
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
958
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
959
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
759
960
  }
760
961
  }
761
962
  }
762
963
  }
763
964
 
764
- template <int mmq_x, int mmq_y, int nwarps>
965
+ template <int mmq_x, int mmq_y>
765
966
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
766
967
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
968
+ #if defined(AMD_MFMA_AVAILABLE)
969
+ typedef tile<16, 8, int> tile_A;
970
+ typedef tile<16, 8, int> tile_B;
971
+ typedef tile<16, 16, int> tile_C;
767
972
 
768
- typedef tile<16, 8, int> tile_A;
769
- typedef tile< 8, 8, int> tile_B;
770
- typedef tile<16, 8, int> tile_C;
973
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
974
+ constexpr int rows_per_warp = granularity;
975
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
976
+
977
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
978
+
979
+ const int * x_qs = (const int *) x;
980
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
981
+ const int * y_qs = (const int *) y + 4;
982
+ const half2 * y_dm = (const half2 *) y;
983
+
984
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
985
+
986
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
987
+ const int k0 = k00 + k01;
988
+
989
+ tile_A A[ntx];
990
+ #pragma unroll
991
+ for (int n = 0; n < ntx; ++n) {
992
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
993
+ }
994
+
995
+ #pragma unroll
996
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
997
+ tile_B B;
998
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
999
+
1000
+ const int j = j0 + tile_C::get_j(0);
1001
+ const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
1002
+
1003
+ #pragma unroll
1004
+ for (int n = 0; n < ntx; ++n) {
1005
+ tile_C C;
1006
+ mma(C, A[n], B);
1007
+
1008
+ #pragma unroll
1009
+ for (int l = 0; l < tile_C::ne; ++l) {
1010
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
1011
+ float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
1012
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
1013
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
1014
+ }
1015
+ }
1016
+ }
1017
+ }
1018
+ #else
1019
+ typedef tile<16, 8, int> tile_A;
1020
+ typedef tile< 8, 8, int> tile_B;
1021
+ typedef tile<16, 8, int> tile_C;
771
1022
 
772
1023
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
773
1024
  constexpr int rows_per_warp = 2 * granularity;
774
1025
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
775
1026
 
776
- y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
1027
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
777
1028
 
778
1029
  const int * x_qs = (const int *) x;
779
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
1030
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
780
1031
  const int * y_qs = (const int *) y + 4;
781
1032
  const half2 * y_dm = (const half2 *) y;
782
1033
 
783
- tile_A A[ntx][WARP_SIZE/QI8_1];
784
- float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
1034
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
1035
+ float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
785
1036
 
786
1037
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
787
1038
 
788
1039
  #pragma unroll
789
1040
  for (int n = 0; n < ntx; ++n) {
790
1041
  #pragma unroll
791
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1042
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
792
1043
  const int k0 = k00 + k01;
793
1044
 
794
1045
  load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
@@ -799,7 +1050,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
799
1050
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
800
1051
 
801
1052
  #pragma unroll
802
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1053
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
803
1054
  const int k0 = k00 + k01;
804
1055
 
805
1056
  dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
@@ -810,7 +1061,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
810
1061
  #pragma unroll
811
1062
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
812
1063
  #pragma unroll
813
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1064
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
814
1065
  tile_B B;
815
1066
  float2 dsB[tile_C::ne/2];
816
1067
 
@@ -836,11 +1087,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
836
1087
  }
837
1088
  }
838
1089
  }
1090
+ #endif // defined(AMD_MFMA_AVAILABLE)
839
1091
  }
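Both branches of vec_dot_q8_1_q8_1_mma accumulate the usual asymmetric-quantization identity: with the x block dequantized as d_x*q_x + m_x and the q8_1 block carrying ds = {d_y, s_y}, where s_y is (up to rounding) the sum of the activations in that block, the dot product splits into d_x*d_y*sum(q_x*q_y) + m_x*s_y, which is exactly the dmA.x*dsB.x*C.x[l] + dmA.y*dsB.y accumulation above. A small numeric check of that identity on toy data:

    // Numeric check: sum_k (d_x*qx_k + m_x)*y_k == d_x*d_y*sum_k(qx_k*qy_k) + m_x*s_y,
    // with y_k == d_y*qy_k and s_y == sum_k(y_k). Toy values; quantization error is not modelled.
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float d_x = 0.125f, m_x = -0.5f, d_y = 0.0625f;
        const int   qx[8] = {1, 3, 7, 15, 2, 0, 9, 11};
        const int   qy[8] = {-5, 2, 8, -1, 3, 7, -4, 6};

        float lhs = 0.0f, sumi = 0.0f, s_y = 0.0f;
        for (int k = 0; k < 8; ++k) {
            const float y = d_y*qy[k];
            lhs  += (d_x*qx[k] + m_x)*y;
            sumi += qx[k]*qy[k];
            s_y  += y;
        }
        const float rhs = d_x*d_y*sumi + m_x*s_y;
        printf("lhs=%f rhs=%f\n", lhs, rhs);
        assert(std::fabs(lhs - rhs) < 1e-5f);
        return 0;
    }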
840
1092
 
841
- template <int mmq_x, int mmq_y, int nwarps>
1093
+ // Used for Q3_K, IQ2_S, and IQ2_XS
1094
+ template <int mmq_x, int mmq_y>
842
1095
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
843
1096
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1097
+ constexpr int nwarps = mmq_get_nwarps_device();
1098
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
844
1099
 
845
1100
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
846
1101
  const int * x_qs = (const int *) x;
@@ -849,7 +1104,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
849
1104
  const float * y_df = (const float *) y;
850
1105
 
851
1106
  // #pragma unroll
852
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
1107
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
853
1108
  const int k0 = k00 + k01;
854
1109
 
855
1110
  #pragma unroll
@@ -857,23 +1112,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
857
1112
  const int j = j0 + threadIdx.y;
858
1113
 
859
1114
  #pragma unroll
860
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1115
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
861
1116
  const int i = i0 + threadIdx.x;
862
1117
 
863
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
864
- &x_qs[i*(2*WARP_SIZE + 1) + k0],
1118
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
1119
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
865
1120
  &y_qs[j*MMQ_TILE_Y_K + k01],
866
- &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
1121
+ &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
867
1122
  y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
868
1123
  }
869
1124
  }
870
1125
  }
871
1126
  }
872
1127
 
873
- template <int mmq_x, int mmq_y, int nwarps>
1128
+ // Used for Q3_K, IQ2_S, and IQ2_XS:
1129
+ template <int mmq_x, int mmq_y>
874
1130
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
875
1131
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
876
- #ifdef NEW_MMA_AVAILABLE
1132
+ #if defined(AMD_MFMA_AVAILABLE)
1133
+ typedef tile<16, 8, int> tile_A;
1134
+ typedef tile<16, 8, int> tile_B;
1135
+ typedef tile<16, 16, int> tile_C;
1136
+ typedef tile<64, 2, int> tile_load;
1137
+
1138
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1139
+ constexpr int rows_per_warp = granularity;
1140
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1141
+
1142
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1143
+
1144
+ const int * x_qs = (const int *) x;
1145
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1146
+ const int * y_qs = (const int *) y + 4;
1147
+ const float * y_df = (const float *) y;
1148
+
1149
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1150
+
1151
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1152
+ const int k0 = k00 + k01;
1153
+
1154
+ tile_A A[ntx];
1155
+ #pragma unroll
1156
+ for (int n = 0; n < ntx; ++n) {
1157
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1158
+ }
1159
+
1160
+ #pragma unroll
1161
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1162
+ tile_B B[1];
1163
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1164
+
1165
+ const int j = j0 + tile_C::get_j(0);
1166
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1167
+
1168
+ #pragma unroll
1169
+ for (int n = 0; n < ntx; ++n) {
1170
+ tile_C C;
1171
+ mma(C, A[n], B[0]);
1172
+
1173
+ #pragma unroll
1174
+ for (int l = 0; l < tile_C::ne; ++l) {
1175
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1176
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1177
+ }
1178
+ }
1179
+ }
1180
+ }
1181
+ #elif defined(TURING_MMA_AVAILABLE)
877
1182
 
878
1183
  typedef tile<16, 4, int> tile_A;
879
1184
  typedef tile<16, 8, int> tile_A_8;
@@ -884,10 +1189,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
884
1189
  constexpr int rows_per_warp = 2 * granularity;
885
1190
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
886
1191
 
887
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1192
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
888
1193
 
889
1194
  const int * x_qs = (const int *) x;
890
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1195
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
891
1196
  const int * y_qs = (const int *) y + 4;
892
1197
  const float * y_df = (const float *) y;
893
1198
 
@@ -899,7 +1204,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
899
1204
  #pragma unroll
900
1205
  for (int n = 0; n < ntx; ++n) {
901
1206
  #pragma unroll
902
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1207
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
903
1208
  const int k0 = k00 + k01;
904
1209
 
905
1210
  load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
@@ -910,7 +1215,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
910
1215
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
911
1216
 
912
1217
  #pragma unroll
913
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
1218
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
914
1219
  const int k0 = k00 + k01;
915
1220
 
916
1221
  dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
@@ -921,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
921
1226
  #pragma unroll
922
1227
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
923
1228
  #pragma unroll
924
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1229
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
925
1230
  tile_B B[2];
926
1231
  float dB[tile_C::ne/2];
927
1232
 
@@ -950,28 +1255,31 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
950
1255
  }
951
1256
  }
952
1257
  #else
953
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1258
+ GGML_UNUSED_VARS(x, y, sum, k00);
954
1259
  NO_DEVICE_CODE;
955
- #endif // NEW_MMA_AVAILABLE
1260
+ #endif // AMD_MFMA_AVAILABLE
956
1261
  }
957
1262
 
958
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1263
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
959
1264
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265
+ constexpr int nwarps = mmq_get_nwarps_device();
960
1266
 
961
- #ifdef NEW_MMA_AVAILABLE
1267
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
962
1268
  int * x_qs = (int *) x_tile;
963
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1269
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
964
1270
  #else
965
1271
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
966
1272
  int * x_qs = (int *) x_tile;
967
1273
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
968
- #endif // NEW_MMA_AVAILABLE
1274
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
969
1275
 
970
- const int kqsx = threadIdx.x % QI2_K;
1276
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277
+ constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
1278
+ const int kqsx = threadIdx.x % threads_per_row;
971
1279
 
972
1280
  #pragma unroll
973
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
974
- int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
1281
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1282
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
975
1283
 
976
1284
  if (need_check) {
977
1285
  i = min(i, i_max);
@@ -987,11 +1295,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
987
1295
 
988
1296
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
989
1297
 
990
- #ifdef NEW_MMA_AVAILABLE
1298
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
991
1299
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
992
1300
  #else
993
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
994
- #endif // NEW_MMA_AVAILABLE
1301
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
995
1303
  }
996
1304
 
997
1305
  const int sc_m = bxi->scales[kqsx];
@@ -1002,17 +1310,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1002
1310
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1003
1311
  #endif // FAST_FP16_AVAILABLE
1004
1312
 
1005
- #ifdef NEW_MMA_AVAILABLE
1313
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1006
1314
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1007
1315
  #else
1008
- x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
1009
- #endif // NEW_MMA_AVAILABLE
1316
+ x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1010
1318
  }
1011
1319
  }
1012
1320
 
1013
- template <int mmq_x, int mmq_y, int nwarps>
1321
+ template <int mmq_x, int mmq_y>
1014
1322
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1015
1323
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1324
+ constexpr int nwarps = mmq_get_nwarps_device();
1325
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1016
1326
 
1017
1327
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1018
1328
  const int * x_qs = (const int *) x;
@@ -1029,7 +1339,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1029
1339
  }
1030
1340
 
1031
1341
  #pragma unroll
1032
- for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1342
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1033
1343
  const int k0 = k00 + k01;
1034
1344
 
1035
1345
  #pragma unroll
@@ -1037,13 +1347,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1037
1347
  const int j = j0 + threadIdx.y;
1038
1348
 
1039
1349
  #pragma unroll
1040
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1350
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1041
1351
  const int i = i0 + threadIdx.x;
1042
1352
 
1043
1353
  constexpr int ns = 2;
1044
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1045
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1046
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1354
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1355
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1356
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1047
1357
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1048
1358
  }
1049
1359
  }
@@ -1052,7 +1362,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
  // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
  // As a workaround 2 separate loops are used instead.
  #pragma unroll
- for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
  const int k0 = k00 + k01;
 
  #pragma unroll
@@ -1060,23 +1370,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
  const int j = j0 + threadIdx.y;
 
  #pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
  const int i = i0 + threadIdx.x;
 
  constexpr int ns = 1;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
  }
  }
  }
  }
 
- template <int mmq_x, int mmq_y, int nwarps>
+ template <int mmq_x, int mmq_y>
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
- #ifdef NEW_MMA_AVAILABLE
+ #if defined(AMD_MFMA_AVAILABLE)
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<16, 8, int> tile_B;
+ typedef tile<16, 16, int> tile_C;
+ typedef tile<64, 2, int> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+ #pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+
+ #pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+ tile_C Cm;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ mma(Cm, A1, B[0]);
+ }
+
+ #pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd;
+ mma(Cd, A[n], B[0]);
+
+ #pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+ float tmp = Cd.x[l]*dm.x;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm.x[l]*dm.y;
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+ }
+ }
+ }
+ }
+ #elif defined(TURING_MMA_AVAILABLE)
1080
1456
 
1081
1457
  typedef tile<16, 4, int> tile_A;
1082
1458
  typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1463,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1087
1463
  constexpr int rows_per_warp = 2 * granularity;
1088
1464
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1089
1465
 
1090
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1466
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1091
1467
 
1092
1468
  const int * x_qs = (const int *) x;
1093
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1469
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1094
1470
  const int * y_qs = (const int *) y + 4;
1095
1471
  const half2 * y_ds = (const half2 *) y;
1096
1472
 
@@ -1103,7 +1479,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1103
1479
  #pragma unroll
1104
1480
  for (int n = 0; n < ntx; ++n) {
1105
1481
  #pragma unroll
1106
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1482
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1107
1483
  const int k0 = k00 + k01;
1108
1484
 
1109
1485
  load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1493,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1117
1493
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1118
1494
 
1119
1495
  #pragma unroll
1120
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
1496
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
1121
1497
  const int k0 = k00 + k01;
1122
1498
 
1123
1499
  const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1516,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1140
1516
  }
1141
1517
 
1142
1518
  #pragma unroll
1143
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1519
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1144
1520
  tile_B B[2];
1145
1521
 
1146
1522
  // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1524,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1148
1524
  load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1149
1525
 
1150
1526
  tile_C Cm[2];
1151
- if (k01 >= WARP_SIZE * 3/4) {
1527
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1152
1528
  tile_A A1;
1153
1529
  A1.x[0] = 0x01010101;
1154
1530
  A1.x[1] = 0x01010101;
@@ -1166,16 +1542,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1166
1542
  #pragma unroll
1167
1543
  for (int l = 0; l < tile_C::ne; ++l) {
1168
1544
  float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1169
- if (k01 >= WARP_SIZE * 3/4) {
1545
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1170
1546
  tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1171
1547
  }
1172
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1548
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
1173
1549
  }
1174
1550
  }
1175
1551
  }
1176
1552
 
1177
1553
  #pragma unroll
1178
- for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1554
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
1179
1555
  float2 sB[tile_C::ne/2];
1180
1556
 
1181
1557
  #pragma unroll
@@ -1196,29 +1572,33 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1196
1572
  }
1197
1573
  }
1198
1574
  #else
1199
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1575
+ GGML_UNUSED_VARS(x, y, sum, k00);
1200
1576
  NO_DEVICE_CODE;
1201
- #endif // NEW_MMA_AVAILABLE
1577
+ #endif // AMD_MFMA_AVAILABLE
1202
1578
  }
1203
1579
 
1204
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1580
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1205
1581
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1582
+ constexpr int nwarps = mmq_get_nwarps_device();
1583
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1206
1584
 
1207
- #ifdef NEW_MMA_AVAILABLE
1585
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1208
1586
  int * x_qs = (int *) x_tile;
1209
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1587
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1210
1588
  #else
1211
1589
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1212
1590
  int * x_qs = (int *) x_tile;
1213
1591
  float * x_df = (float *) (x_qs + txs.qs);
1214
1592
  int * x_sc = (int *) (x_df + txs.dm);
1215
- #endif // NEW_MMA_AVAILABLE
1593
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1216
1594
 
1217
- const int kqsx = threadIdx.x % QI3_K;
1595
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1596
+ constexpr int nrows = warp_size / threads_per_row;
1597
+ const int kqsx = threadIdx.x % threads_per_row;
1218
1598
 
1219
1599
  #pragma unroll
1220
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
1221
- int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
1600
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1601
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1222
1602
 
1223
1603
  if (need_check) {
1224
1604
  i = min(i, i_max);
@@ -1238,17 +1618,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1238
1618
 
1239
1619
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1240
1620
 
1241
- #ifdef NEW_MMA_AVAILABLE
1621
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1242
1622
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1243
1623
  #else
1244
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1245
- #endif // NEW_MMA_AVAILABLE
1624
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1246
1626
  }
1247
1627
  }
1248
1628
 
1629
+ constexpr int rows_per_warp = warp_size / 4;
1249
1630
  #pragma unroll
1250
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1251
- int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
1631
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1632
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
1252
1633
 
1253
1634
  if (need_check) {
1254
1635
  i = min(i, i_max);
@@ -1256,7 +1637,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1256
1637
 
1257
1638
  const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1258
1639
 
1259
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1640
+ const int ksc = threadIdx.x % 4;
1260
1641
 
1261
1642
  const int ksc_low = ksc % (QI3_K/8);
1262
1643
  const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1649,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1268
1649
 
1269
1650
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1270
1651
 
1271
- #ifdef NEW_MMA_AVAILABLE
1652
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1272
1653
  const int8_t * sc8 = (const int8_t *) &sc;
1273
1654
  const float d = bxi->d;
1274
1655
 
1275
1656
  #pragma unroll
1276
1657
  for (int l = 0; l < int(sizeof(int)); ++l) {
1277
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
1658
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
1278
1659
  }
1279
1660
  #else
1280
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1281
- #endif // NEW_MMA_AVAILABLE
1661
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1282
1663
  }
1283
1664
 
1284
- #ifndef NEW_MMA_AVAILABLE
1665
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1285
1666
  #pragma unroll
1286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1287
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1667
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1288
1669
 
1289
1670
  if (need_check) {
1290
1671
  i = min(i, i_max);
@@ -1294,12 +1675,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1294
1675
 
1295
1676
  x_df[i] = bxi->d;
1296
1677
  }
1297
- #endif // NEW_MMA_AVAILABLE
1678
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1298
1679
  }
1299
1680
 
1300
- template <int mmq_x, int mmq_y, int nwarps>
1681
+ template <int mmq_x, int mmq_y>
1301
1682
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1302
1683
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1684
+ constexpr int nwarps = mmq_get_nwarps_device();
1685
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1303
1686
 
1304
1687
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1305
1688
  const int * x_qs = (const int *) x;
@@ -1309,7 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1309
1692
  const float * y_df = (const float *) y;
1310
1693
 
1311
1694
  // #pragma unroll
1312
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1695
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1313
1696
  const int k0 = k00 + k01;
1314
1697
 
1315
1698
  #pragma unroll
@@ -1317,13 +1700,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1317
1700
  const int j = j0 + threadIdx.y;
1318
1701
 
1319
1702
  #pragma unroll
1320
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1703
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1321
1704
  const int i = i0 + threadIdx.x;
1322
1705
 
1323
- const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
1706
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
1324
1707
 
1325
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1326
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1708
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
1709
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1327
1710
  x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1328
1711
  }
1329
1712
  }
@@ -1340,72 +1723,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
1340
1723
  ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1341
1724
  }
1342
1725
 
1343
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1726
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1344
1727
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1728
+ constexpr int nwarps = mmq_get_nwarps_device();
1729
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1345
1730
 
1346
- #ifdef NEW_MMA_AVAILABLE
1731
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1347
1732
  int * x_qs = (int *) x_tile;
1348
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1733
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1349
1734
  #else
1350
1735
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1351
1736
  int * x_qs = (int *) x_tile;
1352
1737
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1353
1738
  int * x_sc = (int *) (x_dm + txs.dm);
1354
- #endif // NEW_MMA_AVAILABLE
1739
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1740
+
1741
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742
+ constexpr int nrows = warp_size / threads_per_row;
1743
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1355
1744
 
1356
1745
  #pragma unroll
1357
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1358
- int i = i0 + threadIdx.y;
1746
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1747
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1359
1748
 
1360
1749
  if (need_check) {
1361
1750
  i = min(i, i_max);
1362
1751
  }
1363
1752
 
1364
1753
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1365
- const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
1754
+ const int qs0 = get_int_b4(bxi->qs, txi);
1366
1755
 
1367
- #ifdef NEW_MMA_AVAILABLE
1368
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1369
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1756
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1757
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1370
1759
  #else
1371
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1372
- #endif // NEW_MMA_AVAILABLE
1760
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1373
1762
  }
1374
1763
 
1375
- #ifdef NEW_MMA_AVAILABLE
1376
-
1764
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1765
+ constexpr int rows_per_warp = warp_size / 2;
1377
1766
  #pragma unroll
1378
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1379
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1380
-
1381
- if (need_check) {
1382
- i = min(i, i_max);
1383
- }
1767
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1768
+ #if defined(AMD_MFMA_AVAILABLE)
1769
+ // Need if on AMD instead of % because warp_size == 64
1770
+ // This causes double work and throughput loss (MI300X)
1771
+ // H100 loses about 100 t/s with 'if' condition over '%'
1772
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1773
+ if (i < mmq_y) {
1774
+ #else
1775
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1776
+ {
1777
+ #endif // defined(AMD_MFMA_AVAILABLE)
1778
+ if (need_check) {
1779
+ i = min(i, i_max);
1780
+ }
1384
1781
 
1385
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1782
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1386
1783
 
1387
- const int * scales = (const int *) bxi->scales;
1388
- const int ksc = threadIdx.x % (WARP_SIZE/16);
1784
+ const int * scales = (const int *) bxi->scales;
1785
+ const int ksc = threadIdx.x % 2;
1389
1786
 
1390
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1391
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1787
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1788
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1392
1789
 
1393
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1394
- const uint8_t * m8 = (const uint8_t *) &m32;
1790
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1791
+ const uint8_t * m8 = (const uint8_t *) &m32;
1395
1792
 
1396
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1793
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1397
1794
 
1398
- #pragma unroll
1399
- for (int l = 0; l < int(sizeof(int)); ++l) {
1400
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1795
+ #pragma unroll
1796
+ for (int l = 0; l < sizeof(int); ++l) {
1797
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1798
+ }
1401
1799
  }
1402
1800
  }
1403
-
1404
1801
  #else
1405
-
1406
1802
  #pragma unroll
1407
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
1408
- int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
1803
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1804
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1409
1805
 
1410
1806
  if (need_check) {
1411
1807
  i = min(i, i_max);
@@ -1415,30 +1811,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1415
1811
 
1416
1812
  x_dm[i] = bxi->dm;
1417
1813
  }
1418
-
1814
+ constexpr int rows_per_warp = warp_size / 4;
1419
1815
  #pragma unroll
1420
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1421
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
1816
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1817
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1422
1818
 
1423
1819
  if (need_check) {
1424
1820
  i = min(i, i_max);
1425
1821
  }
1426
1822
 
1427
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
1823
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
1428
1824
 
1429
1825
  const int * scales = (const int *) bxi->scales;
1430
1826
 
1431
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1827
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1432
1828
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1433
1829
 
1434
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1830
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1435
1831
  }
1436
- #endif // NEW_MMA_AVAILABLE
1832
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1437
1833
  }
1438
1834
 
1439
- template <int mmq_x, int mmq_y, int nwarps>
1835
+ template <int mmq_x, int mmq_y>
1440
1836
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1441
1837
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1838
+ constexpr int nwarps = mmq_get_nwarps_device();
1839
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1442
1840
 
1443
1841
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1444
1842
  const int * x_qs = (const int *) x;
@@ -1448,7 +1846,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1448
1846
  const half2 * y_ds = (const half2 *) y;
1449
1847
 
1450
1848
  // #pragma unroll
1451
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1849
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1452
1850
  const int k0 = k00 + k01;
1453
1851
 
1454
1852
  #pragma unroll
@@ -1456,97 +1854,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1456
1854
  const int j = j0 + threadIdx.y;
1457
1855
 
1458
1856
  #pragma unroll
1459
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1857
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1460
1858
  const int i = i0 + threadIdx.x;
1461
1859
 
1462
- const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
1860
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
1463
1861
 
1464
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1465
- &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1862
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
1863
+ &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1466
1864
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1467
1865
  }
1468
1866
  }
1469
1867
  }
1470
1868
  }
1471
1869
 
1472
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1870
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1473
1871
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1872
+ constexpr int nwarps = mmq_get_nwarps_device();
1873
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1474
1874
 
1475
- #ifdef NEW_MMA_AVAILABLE
1875
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1476
1876
  int * x_qs = (int *) x_tile;
1477
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
1877
+ half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1478
1878
  #else
1479
1879
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1480
1880
  int * x_qs = (int *) x_tile;
1481
1881
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1482
1882
  int * x_sc = (int *) (x_dm + txs.dm);
1483
- #endif // NEW_MMA_AVAILABLE
1883
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1884
+
1885
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
1886
+ constexpr int nrows = warp_size / threads_per_row;
1887
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1484
1888
 
1485
1889
  #pragma unroll
1486
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1487
- int i = i0 + threadIdx.y;
1890
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1891
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1488
1892
 
1489
1893
  if (need_check) {
1490
1894
  i = min(i, i_max);
1491
1895
  }
1492
1896
 
1493
1897
  const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1494
- const int ky = QR5_K*threadIdx.x;
1898
+ const int ky = QR5_K*txi;
1495
1899
 
1496
- const int ql = get_int_b4(bxi->qs, threadIdx.x);
1900
+ const int ql = get_int_b4(bxi->qs, txi);
1497
1901
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1498
1902
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1499
1903
 
1500
- const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
1501
- const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1502
- const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1904
+ const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
1905
+ const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1906
+ const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1503
1907
 
1504
- const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1505
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
1908
+ const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909
+ const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1506
1910
 
1507
- #ifdef NEW_MMA_AVAILABLE
1911
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1508
1912
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1509
1913
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1510
1914
  #else
1511
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1512
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1513
- #endif // NEW_MMA_AVAILABLE
1915
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1514
1918
  }
1515
1919
 
1516
- #ifdef NEW_MMA_AVAILABLE
1517
-
1920
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1921
+ constexpr int rows_per_warp = warp_size / 2;
1518
1922
  #pragma unroll
1519
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1520
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1521
-
1522
- if (need_check) {
1523
- i = min(i, i_max);
1524
- }
1923
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1924
+ #if defined(AMD_MFMA_AVAILABLE)
1925
+ // Need if on AMD instead of % because warp_size == 64
1926
+ // This causes double work and throughput loss (MI300X)
1927
+ // H100 loses about 100 t/s with 'if' condition over '%'
1928
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1929
+ if (i < mmq_y) {
1930
+ #else
1931
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1932
+ {
1933
+ #endif // defined(AMD_MFMA_AVAILABLE)
1934
+ if (need_check) {
1935
+ i = min(i, i_max);
1936
+ }
1525
1937
 
1526
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1938
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1527
1939
 
1528
- const int * scales = (const int *) bxi->scales;
1529
- const int ksc = threadIdx.x % (WARP_SIZE/16);
1940
+ const int * scales = (const int *) bxi->scales;
1941
+ const int ksc = threadIdx.x % 2;
1530
1942
 
1531
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1532
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1943
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1944
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1533
1945
 
1534
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1535
- const uint8_t * m8 = (const uint8_t *) &m32;
1946
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1947
+ const uint8_t * m8 = (const uint8_t *) &m32;
1536
1948
 
1537
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1949
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1538
1950
 
1539
1951
  #pragma unroll
1540
- for (int l = 0; l < int(sizeof(int)); ++l) {
1541
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1952
+ for (int l = 0; l < int(sizeof(int)); ++l) {
1953
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1954
+ }
1542
1955
  }
1543
1956
  }
1544
-
1545
1957
  #else
1546
-
1547
1958
  #pragma unroll
1548
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
1549
- int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
1959
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1960
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1550
1961
 
1551
1962
  if (need_check) {
1552
1963
  i = min(i, i_max);
@@ -1557,9 +1968,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1557
1968
  x_dm[i] = bxi->dm;
1558
1969
  }
1559
1970
 
1971
+ constexpr int rows_per_warp = warp_size / 4;
1560
1972
  #pragma unroll
1561
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1562
- int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
1973
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1974
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1563
1975
 
1564
1976
  if (need_check) {
1565
1977
  i = min(i, i_max);
@@ -1569,17 +1981,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1569
1981
 
1570
1982
  const int * scales = (const int *) bxi->scales;
1571
1983
 
1572
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1984
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1573
1985
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1574
1986
 
1575
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1987
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1576
1988
  }
1577
- #endif // NEW_MMA_AVAILABLE
1989
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1578
1990
  }
1579
1991
 
1580
- template <int mmq_x, int mmq_y, int nwarps>
1992
+ template <int mmq_x, int mmq_y>
1581
1993
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1582
1994
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1995
+ constexpr int nwarps = mmq_get_nwarps_device();
1996
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1583
1997
 
1584
1998
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1585
1999
  const int * x_qs = (const int *) x;
@@ -1589,7 +2003,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1589
2003
  const half2 * y_ds = (const half2 *) y;
1590
2004
 
1591
2005
  // #pragma unroll
1592
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
2006
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1593
2007
  const int k0 = k00 + k01;
1594
2008
 
1595
2009
  #pragma unroll
@@ -1597,36 +2011,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1597
2011
  const int j = j0 + threadIdx.y;
1598
2012
 
1599
2013
  #pragma unroll
1600
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2014
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1601
2015
  const int i = i0 + threadIdx.x;
1602
2016
 
1603
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
2017
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
1604
2018
 
1605
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1606
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2019
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
2020
+ &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1607
2021
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1608
2022
  }
1609
2023
  }
1610
2024
  }
1611
2025
  }
1612
2026
 
1613
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
2027
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1614
2028
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2029
+ constexpr int nwarps = mmq_get_nwarps_device();
2030
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1615
2031
 
1616
- #ifdef NEW_MMA_AVAILABLE
2032
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1617
2033
  int * x_qs = (int *) x_tile;
1618
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1619
- int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
2034
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035
+ int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
1620
2036
  #else
1621
2037
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1622
2038
  int * x_qs = (int *) x_tile;
1623
2039
  float * x_df = (float *) (x_qs + txs.qs);
1624
2040
  int * x_sc = (int *) (x_df + txs.dm);
1625
- #endif // NEW_MMA_AVAILABLE
2041
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2042
+
2043
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044
+ constexpr int nrows = warp_size / threads_per_row;
2045
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1626
2046
 
1627
2047
  #pragma unroll
1628
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1629
- int i = i0 + threadIdx.y;
2048
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2049
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1630
2050
 
1631
2051
  if (need_check) {
1632
2052
  i = min(i, i_max);
@@ -1634,67 +2054,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1634
2054
 
1635
2055
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1636
2056
 
1637
- const int ql = get_int_b2(bxi->ql, threadIdx.x);
2057
+ const int ql = get_int_b2(bxi->ql, txi);
1638
2058
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1639
2059
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1640
2060
 
1641
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
1642
- const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
1643
- const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
2061
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
2062
+ const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
2063
+ const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
1644
2064
 
1645
- const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1646
- const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
2065
+ const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066
+ const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
1647
2067
 
1648
- #ifdef NEW_MMA_AVAILABLE
2068
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1649
2069
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1650
2070
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1651
2071
  #else
1652
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1653
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1654
- #endif // NEW_MMA_AVAILABLE
2072
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2073
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1655
2075
  }
1656
2076
 
1657
- const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
1658
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
1659
-
1660
2077
  #pragma unroll
1661
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
1662
- int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
2078
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2079
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1663
2080
 
1664
2081
  if (need_check) {
1665
2082
  i = min(i, i_max);
1666
2083
  }
1667
2084
 
1668
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
2085
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1669
2086
 
1670
- #ifdef NEW_MMA_AVAILABLE
1671
- x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
2087
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2088
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
1672
2089
  #else
1673
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
1674
- #endif // NEW_MMA_AVAILABLE
2090
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1675
2092
  }
1676
2093
 
2094
+ constexpr int rows_per_warp = warp_size / 4;
1677
2095
  #pragma unroll
1678
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1679
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
2096
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2097
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1680
2098
 
1681
2099
  if (need_check) {
1682
2100
  i = min(i, i_max);
1683
2101
  }
1684
2102
 
1685
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
2103
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
1686
2104
 
1687
- #ifdef NEW_MMA_AVAILABLE
1688
- x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
2105
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2106
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
1689
2107
  #else
1690
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1691
- #endif // NEW_MMA_AVAILABLE
2108
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1692
2110
  }
1693
2111
  }
1694
2112
 
1695
- template <int mmq_x, int mmq_y, int nwarps>
2113
+ template <int mmq_x, int mmq_y>
1696
2114
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1697
2115
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2116
+ constexpr int nwarps = mmq_get_nwarps_device();
2117
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1698
2118
 
1699
2119
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1700
2120
  const int * x_qs = (const int *) x;
@@ -1704,7 +2124,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1704
2124
  const float * y_df = (const float *) y;
1705
2125
 
1706
2126
  // #pragma unroll
1707
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2127
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
1708
2128
  const int k0 = k00 + k01;
1709
2129
 
1710
2130
  #pragma unroll
@@ -1712,23 +2132,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1712
2132
  const int j = j0 + threadIdx.y;
1713
2133
 
1714
2134
  #pragma unroll
1715
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2135
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1716
2136
  const int i = i0 + threadIdx.x;
1717
2137
 
1718
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
2138
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
1719
2139
 
1720
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1721
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
1722
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2140
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
2141
+ &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2142
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1723
2143
  }
1724
2144
  }
1725
2145
  }
1726
2146
  }
1727
2147
 
1728
- template <int mmq_x, int mmq_y, int nwarps>
2148
+ template <int mmq_x, int mmq_y>
1729
2149
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1730
2150
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1731
- #ifdef NEW_MMA_AVAILABLE
2151
+ #if defined(AMD_MFMA_AVAILABLE)
2152
+ typedef tile<16, 8, int> tile_A;
2153
+ typedef tile<16, 8, int> tile_B;
2154
+ typedef tile<16, 16, int> tile_C;
2155
+ typedef tile<64, 2, int> tile_load;
2156
+
2157
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2158
+ constexpr int rows_per_warp = granularity;
2159
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2160
+
2161
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2162
+
2163
+ const int * x_qs = (const int *) x;
2164
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2165
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2166
+ const int * y_qs = (const int *) y + 4;
2167
+ const float * y_df = (const float *) y;
2168
+
2169
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2170
+
2171
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2172
+ const int k0 = k00 + k01;
2173
+
2174
+ tile_A A[ntx];
2175
+ #pragma unroll
2176
+ for (int n = 0; n < ntx; ++n) {
2177
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2178
+ }
2179
+
2180
+ #pragma unroll
2181
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2182
+ tile_B B[1];
2183
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2184
+
2185
+ const int j = j0 + tile_C::get_j(0);
2186
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2187
+
2188
+ #pragma unroll
2189
+ for (int n = 0; n < ntx; ++n) {
2190
+ tile_C C;
2191
+ mma(C, A[n], B[0]);
2192
+
2193
+ #pragma unroll
2194
+ for (int l = 0; l < tile_C::ne; ++l) {
2195
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2196
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2197
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2198
+ }
2199
+ }
2200
+ }
2201
+ }
2202
+ #elif defined(TURING_MMA_AVAILABLE)
1732
2203
 
1733
2204
  typedef tile<16, 4, int> tile_A;
1734
2205
  typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2209,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1738
2209
  constexpr int rows_per_warp = 2 * granularity;
1739
2210
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1740
2211
 
1741
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
2212
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1742
2213
 
1743
2214
  const int * x_qs = (const int *) x;
1744
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1745
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
2215
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2216
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
1746
2217
  const int * y_qs = (const int *) y + 4;
1747
2218
  const float * y_df = (const float *) y;
1748
2219
 
@@ -1755,7 +2226,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1755
2226
  #pragma unroll
1756
2227
  for (int n = 0; n < ntx; ++n) {
1757
2228
  #pragma unroll
1758
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2229
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1759
2230
  const int k0 = k00 + k01;
1760
2231
 
1761
2232
  load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2234,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1763
2234
  }
1764
2235
 
1765
2236
  #pragma unroll
1766
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
2237
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
1767
2238
  const int k0 = k00 + k01;
1768
2239
 
1769
2240
  #pragma unroll
@@ -1793,7 +2264,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1793
2264
  float tmp[ntx][tile_C::ne] = {{0.0f}};
1794
2265
 
1795
2266
  #pragma unroll
1796
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2267
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1797
2268
  tile_B B[2];
1798
2269
  float dB[tile_C::ne/2];
1799
2270
 
@@ -1830,29 +2301,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1830
2301
  }
1831
2302
  }
1832
2303
  #else
1833
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
2304
+ GGML_UNUSED_VARS(x, y, sum, k00);
1834
2305
  NO_DEVICE_CODE;
1835
- #endif // NEW_MMA_AVAILABLE
2306
+ #endif // AMD_MFMA_AVAILABLE
1836
2307
  }
1837
2308
 
1838
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
2309
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1839
2310
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2311
+ constexpr int nwarps = mmq_get_nwarps_device();
2312
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1840
2313
 
1841
- #ifdef NEW_MMA_AVAILABLE
2314
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1842
2315
  int * x_qs = (int *) x_tile;
1843
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2316
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1844
2317
  #else
1845
2318
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
1846
2319
  int * x_qs = (int *) x_tile;
1847
2320
  float * x_df = (float *) (x_qs + txs.qs);
1848
- #endif // NEW_MMA_AVAILABLE
2321
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1849
2322
 
1850
- const int kbx = threadIdx.x / QI4_NL;
1851
- const int kqsx = threadIdx.x % QI4_NL;
2323
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324
+ constexpr int nrows = warp_size / threads_per_row;
2325
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2326
+ const int kbx = txi / QI4_NL;
2327
+ const int kqsx = txi % QI4_NL;
1852
2328
 
1853
2329
  #pragma unroll
1854
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1855
- int i = i0 + threadIdx.y;
2330
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2331
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1856
2332
 
1857
2333
  if (need_check) {
1858
2334
  i = min(i, i_max);
@@ -1861,23 +2337,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1861
2337
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
1862
2338
 
1863
2339
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
1864
- const int2 v = get_int_from_table_16(aux_q4);
1865
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
1866
- #ifdef NEW_MMA_AVAILABLE
1867
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
1868
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2340
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
+ const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
+
2343
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2344
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
1869
2346
  #else
1870
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
1871
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
1872
- #endif // NEW_MMA_AVAILABLE
2347
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1873
2350
  }
1874
2351
 
1875
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
2352
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
2353
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
1876
2354
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
1877
2355
 
1878
2356
  #pragma unroll
1879
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
1880
- int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
2357
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2358
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
1881
2359
 
1882
2360
  if (need_check) {
1883
2361
  i = min(i, i_max);
@@ -1885,31 +2363,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1885
2363
 
1886
2364
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
1887
2365
 
1888
- #ifdef NEW_MMA_AVAILABLE
1889
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2366
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2367
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
1890
2368
  #else
1891
- x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
1892
- #endif // NEW_MMA_AVAILABLE
2369
+ x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1893
2371
  }
1894
2372
  }
1895
2373
 
1896
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
2374
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1897
2375
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2376
+ constexpr int nwarps = mmq_get_nwarps_device();
2377
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1898
2378
 
1899
- #ifdef NEW_MMA_AVAILABLE
2379
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1900
2380
  int * x_qs = (int *) x_tile;
1901
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2381
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1902
2382
  #else
1903
2383
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1904
2384
  int * x_qs = (int *) x_tile;
1905
2385
  float * x_df = (float *) (x_qs + txs.qs);
1906
- #endif // NEW_MMA_AVAILABLE
2386
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1907
2387
 
1908
- const int kqsx = threadIdx.x % (QI2_XXS/2);
2388
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389
+ constexpr int nrows = warp_size / threads_per_row;
2390
+ const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1909
2391
 
1910
2392
  #pragma unroll
1911
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
1912
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
2393
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2394
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1913
2395
 
1914
2396
  if (need_check) {
1915
2397
  i = min(i, i_max);
@@ -1932,42 +2414,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1932
2414
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1933
2415
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1934
2416
 
1935
- #ifdef NEW_MMA_AVAILABLE
2417
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1936
2418
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1937
2419
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1938
2420
  #else
1939
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1940
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1941
- #endif // NEW_MMA_AVAILABLE
2421
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1942
2424
  }
1943
2425
 
1944
2426
  const int ls = aux32 >> 28;
1945
2427
  const float d = bxi->d;
1946
- #ifdef NEW_MMA_AVAILABLE
1947
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2428
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2429
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1948
2430
  #else
1949
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1950
- #endif // NEW_MMA_AVAILABLE
2431
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1951
2433
  }
1952
2434
  }
1953
2435
 
1954
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
2436
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1955
2437
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2438
+ constexpr int nwarps = mmq_get_nwarps_device();
2439
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1956
2440
 
1957
- #ifdef NEW_MMA_AVAILABLE
2441
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1958
2442
  int * x_qs = (int *) x_tile;
1959
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2443
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1960
2444
  #else
1961
2445
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1962
2446
  int * x_qs = (int *) x_tile;
1963
2447
  float * x_df = (float *) (x_qs + txs.qs);
1964
- #endif // NEW_MMA_AVAILABLE
2448
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1965
2449
 
1966
- const int kqsx = threadIdx.x % (QI2_XS/2);
2450
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451
+ constexpr int nrows = warp_size / threads_per_row;
2452
+ const int kqsx = threadIdx.x % threads_per_row;
1967
2453
 
1968
2454
  #pragma unroll
1969
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
1970
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
2455
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2456
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1971
2457
 
1972
2458
  if (need_check) {
1973
2459
  i = min(i, i_max);
@@ -1986,44 +2472,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1986
2472
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1987
2473
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1988
2474
 
1989
- #ifdef NEW_MMA_AVAILABLE
2475
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1990
2476
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1991
2477
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1992
2478
  #else
1993
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1994
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1995
- #endif // NEW_MMA_AVAILABLE
2479
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1996
2482
  }
1997
2483
 
1998
2484
  const int ls = bxi->scales[kqsx];
1999
2485
  const float d = bxi->d;
2000
- #ifdef NEW_MMA_AVAILABLE
2001
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2002
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2486
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2487
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2003
2489
  #else
2004
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2005
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2006
- #endif // NEW_MMA_AVAILABLE
2490
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2007
2493
  }
2008
2494
  }
2009
2495
 
2010
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2496
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2011
2497
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2498
+ constexpr int nwarps = mmq_get_nwarps_device();
2499
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2012
2500
 
2013
- #ifdef NEW_MMA_AVAILABLE
2501
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2014
2502
  int * x_qs = (int *) x_tile;
2015
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2503
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2016
2504
  #else
2017
2505
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2018
2506
  int * x_qs = (int *) x_tile;
2019
2507
  float * x_df = (float *) (x_qs + txs.qs);
2020
- #endif // NEW_MMA_AVAILABLE
2508
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2021
2509
 
2022
- const int kqsx = threadIdx.x % (QI2_S/2);
2510
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511
+ constexpr int nrows = warp_size / threads_per_row;
2512
+ const int kqsx = threadIdx.x % threads_per_row;
2023
2513
 
2024
2514
  #pragma unroll
2025
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
2026
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
2515
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2516
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2027
2517
 
2028
2518
  if (need_check) {
2029
2519
  i = min(i, i_max);
@@ -2049,44 +2539,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2049
2539
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2050
2540
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2051
2541
 
2052
- #ifdef NEW_MMA_AVAILABLE
2542
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2053
2543
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2054
2544
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2055
2545
  #else
2056
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2057
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2058
- #endif // NEW_MMA_AVAILABLE
2546
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2059
2549
  }
2060
2550
 
2061
2551
  const int ls = bxi->scales[kqsx];
2062
2552
  const float d = bxi->d;
2063
- #ifdef NEW_MMA_AVAILABLE
2064
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2065
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2553
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2554
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2066
2556
  #else
2067
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2068
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2069
- #endif // NEW_MMA_AVAILABLE
2557
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2070
2560
  }
2071
2561
  }
2072
2562
 
2073
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2563
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2074
2564
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2565
+ constexpr int nwarps = mmq_get_nwarps_device();
2566
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2075
2567
 
2076
- #ifdef NEW_MMA_AVAILABLE
2568
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2077
2569
  int * x_qs = (int *) x_tile;
2078
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2570
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2079
2571
  #else
2080
2572
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2081
2573
  int * x_qs = (int *) x_tile;
2082
2574
  float * x_df = (float *) (x_qs + txs.qs);
2083
- #endif // NEW_MMA_AVAILABLE
2575
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2084
2576
 
2085
- const int kqsx = threadIdx.x % (QI3_XXS/2);
2577
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2578
+ constexpr int nrows = warp_size / threads_per_row;
2579
+ const int kqsx = threadIdx.x % threads_per_row;
2086
2580
 
2087
2581
  #pragma unroll
2088
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
2089
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
2582
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2583
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2090
2584
 
2091
2585
  if (need_check) {
2092
2586
  i = min(i, i_max);
@@ -2107,42 +2601,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2107
2601
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2108
2602
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2109
2603
 
2110
- #ifdef NEW_MMA_AVAILABLE
2604
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2111
2605
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2112
2606
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2113
2607
  #else
2114
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2115
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2116
- #endif // NEW_MMA_AVAILABLE
2608
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2609
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2610
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2117
2611
  }
2118
2612
 
2119
2613
  const int ls = aux32 >> 28;
2120
2614
  const float d = bxi->d;
2121
- #ifdef NEW_MMA_AVAILABLE
2122
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2615
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2616
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2123
2617
  #else
2124
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2125
- #endif // NEW_MMA_AVAILABLE
2618
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2619
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2126
2620
  }
2127
2621
  }
2128
2622
 
2129
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2623
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2130
2624
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2625
+ constexpr int nwarps = mmq_get_nwarps_device();
2626
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2131
2627
 
2132
- #ifdef NEW_MMA_AVAILABLE
2628
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2133
2629
  int * x_qs = (int *) x_tile;
2134
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2630
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2135
2631
  #else
2136
2632
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2137
2633
  int * x_qs = (int *) x_tile;
2138
2634
  float * x_df = (float *) (x_qs + txs.qs);
2139
- #endif // NEW_MMA_AVAILABLE
2635
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2140
2636
 
2141
- const int kqsx = threadIdx.x % (QI3_S/2);
2637
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2638
+ constexpr int nrows = warp_size / threads_per_row;
2639
+ const int kqsx = threadIdx.x % threads_per_row;
2142
2640
 
2143
2641
  #pragma unroll
2144
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
2145
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
2642
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2643
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2146
2644
 
2147
2645
  if (need_check) {
2148
2646
  i = min(i, i_max);
@@ -2170,42 +2668,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2170
2668
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2171
2669
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2172
2670
 
2173
- #ifdef NEW_MMA_AVAILABLE
2671
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2174
2672
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2175
2673
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2176
2674
  #else
2177
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2178
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2179
- #endif // NEW_MMA_AVAILABLE
2675
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2676
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2677
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2180
2678
  }
2181
2679
 
2182
2680
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2183
2681
  const float d = bxi->d;
2184
- #ifdef NEW_MMA_AVAILABLE
2185
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2682
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2683
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2186
2684
  #else
2187
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2188
- #endif // NEW_MMA_AVAILABLE
2685
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2686
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2189
2687
  }
2190
2688
  }
2191
2689
 
2192
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2690
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2193
2691
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2692
+ constexpr int nwarps = mmq_get_nwarps_device();
2693
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2194
2694
 
2195
- #ifdef NEW_MMA_AVAILABLE
2695
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2196
2696
  int * x_qs = (int *) x_tile;
2197
- half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
2697
+ half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2198
2698
  #else
2199
2699
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2200
2700
  int * x_qs = (int *) x_tile;
2201
2701
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2202
- #endif // NEW_MMA_AVAILABLE
2702
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2203
2703
 
2204
- const int kqsx = threadIdx.x % QI1_S;
2704
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2705
+ constexpr int nrows = warp_size / threads_per_row;
2706
+ const int kqsx = threadIdx.x % threads_per_row;
2205
2707
 
2206
2708
  #pragma unroll
2207
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
2208
- int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
2709
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2710
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2209
2711
 
2210
2712
  if (need_check) {
2211
2713
  i = min(i, i_max);
@@ -2225,66 +2727,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2225
2727
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2226
2728
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2227
2729
 
2228
- #ifdef NEW_MMA_AVAILABLE
2730
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2229
2731
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2230
2732
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2231
2733
  #else
2232
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2233
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2234
- #endif // NEW_MMA_AVAILABLE
2734
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2735
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2736
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2235
2737
  }
2236
2738
 
2237
2739
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2238
2740
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2239
2741
 
2240
- #ifdef NEW_MMA_AVAILABLE
2241
- x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2742
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2743
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2242
2744
  #else
2243
- x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2244
- #endif // NEW_MMA_AVAILABLE
2745
+ x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2746
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2245
2747
  }
2246
2748
  }
2247
2749
 
2248
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2750
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2249
2751
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2752
+ constexpr int nwarps = mmq_get_nwarps_device();
2753
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2250
2754
 
2251
- #ifdef NEW_MMA_AVAILABLE
2755
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2252
2756
  int * x_qs = (int *) x_tile;
2253
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2757
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2254
2758
  #else
2255
2759
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2256
2760
  int * x_qs = (int *) x_tile;
2257
2761
  float * x_df = (float *) (x_qs + txs.qs);
2258
- #endif // NEW_MMA_AVAILABLE
2762
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2259
2763
 
2260
- const int kbx = 0; // threadIdx.x / QI4_XS
2261
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2764
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2765
+ constexpr int nrows = warp_size / threads_per_row;
2766
+ const int kqsx = threadIdx.x % threads_per_row;
2262
2767
 
2263
2768
  #pragma unroll
2264
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2265
- int i = i0 + threadIdx.y;
2769
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2770
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2266
2771
 
2267
2772
  if (need_check) {
2268
2773
  i = min(i, i_max);
2269
2774
  }
2270
2775
 
2271
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
2776
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2272
2777
 
2273
2778
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2274
- const int2 v = get_int_from_table_16(aux_q4);
2275
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2276
- #ifdef NEW_MMA_AVAILABLE
2779
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780
+ const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
+
2782
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2277
2783
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2278
2784
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2279
2785
  #else
2280
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2281
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2282
- #endif // NEW_MMA_AVAILABLE
2786
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2787
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2788
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2283
2789
  }
2284
2790
 
2791
+ constexpr int rows_per_warp = warp_size / 8;
2285
2792
  #pragma unroll
2286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2287
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2793
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2794
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
2288
2795
 
2289
2796
  if (need_check) {
2290
2797
  i = min(i, i_max);
@@ -2297,18 +2804,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2297
2804
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2298
2805
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2299
2806
 
2300
- #ifdef NEW_MMA_AVAILABLE
2301
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2807
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2808
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2302
2809
  #else
2303
- x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2304
- #endif // NEW_MMA_AVAILABLE
2810
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2811
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2305
2812
  }
2306
2813
  }
2307
2814
 
2308
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2815
+ template<int mmq_x, int mmq_y, bool need_check>
2309
2816
  static __device__ __forceinline__ void mmq_write_back_dp4a(
2310
2817
  const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
2311
2818
  const int stride, const int i_max, const int j_max) {
2819
+ constexpr int nwarps = mmq_get_nwarps_device();
2820
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2821
+
2312
2822
  #pragma unroll
2313
2823
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2314
2824
  const int j = j0 + threadIdx.y;
@@ -2318,32 +2828,42 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
2318
2828
  }
2319
2829
 
2320
2830
  #pragma unroll
2321
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2831
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2322
2832
  const int i = i0 + threadIdx.x;
2323
2833
 
2324
2834
  if (need_check && i > i_max) {
2325
2835
  continue;
2326
2836
  }
2327
2837
 
2328
- dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2838
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2329
2839
  }
2330
2840
  }
2331
2841
  }
2332
2842
 
2333
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2843
+ template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
2334
2844
  static __device__ __forceinline__ void mmq_write_back_mma(
2335
2845
  const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
2336
2846
  const int stride, const int i_max, const int j_max) {
2337
- typedef tile<16, 8, int> tile_C;
2338
2847
 
2339
2848
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2849
+ constexpr int nwarps = mmq_get_nwarps_device();
2850
+
2851
+ #if defined(AMD_MFMA_AVAILABLE)
2852
+ constexpr int tileC_IJ = mmq_get_granularity_device(0);
2853
+ typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
2854
+ constexpr int rows_per_warp = granularity;
2855
+ #else
2856
+ typedef tile<16, 8, int> tile_C;
2340
2857
  constexpr int rows_per_warp = 2 * granularity;
2858
+ #endif // defined(AMD_MFMA_AVAILABLE)
2341
2859
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2342
2860
 
2343
2861
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2344
- #ifdef NEW_MMA_AVAILABLE
2862
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
2345
2863
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2346
- #endif // NEW_MMA_AVAILABLE
2864
+ #else
2865
+ GGML_UNUSED(nwarps);
2866
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2347
2867
 
2348
2868
  #pragma unroll
2349
2869
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2371,179 +2891,189 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2371
2891
 
2372
2892
  // -------------------------------------------------------------------------------------------------------------------------------------
2373
2893
 
2374
- template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
2894
+ template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
2375
2895
  struct mmq_type_traits;
2376
2896
 
2377
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2378
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
2897
+ template <int mmq_x, int mmq_y, bool need_check>
2898
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
2379
2899
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2380
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
2381
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
2382
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2900
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
2901
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
2902
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
2383
2903
  };
2384
2904
 
2385
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2386
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
2905
+ template <int mmq_x, int mmq_y, bool need_check>
2906
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
2387
2907
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2388
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
2389
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2390
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2908
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
2909
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2910
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
2391
2911
  };
2392
2912
 
2393
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2394
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2913
+ template <int mmq_x, int mmq_y, bool need_check>
2914
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
2395
2915
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2396
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2397
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2398
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2916
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
2917
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2918
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2399
2919
  };
2400
2920
 
2401
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2402
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
2921
+ template <int mmq_x, int mmq_y, bool need_check>
2922
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
2403
2923
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2404
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
2405
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2406
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2924
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
2925
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2926
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2407
2927
  };
2408
2928
 
2409
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2410
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
2929
+ template <int mmq_x, int mmq_y, bool need_check>
2930
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2411
2931
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2412
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
2413
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2414
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2932
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
2933
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2934
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2935
+ };
2936
+
2937
+ template <int mmq_x, int mmq_y, bool need_check>
2938
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
2939
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
2940
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
2941
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2942
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2415
2943
  };
2416
2944
 
2417
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2418
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
2945
+ template <int mmq_x, int mmq_y, bool need_check>
2946
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2419
2947
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
2420
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
2421
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2422
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2948
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
2949
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
2950
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
2423
2951
  };
2424
2952
 
2425
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2426
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
2953
+ template <int mmq_x, int mmq_y, bool need_check>
2954
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
2427
2955
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2428
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
2429
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2430
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2956
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
2957
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
2958
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
2431
2959
  };
2432
2960
 
2433
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2434
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
2961
+ template <int mmq_x, int mmq_y, bool need_check>
2962
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
2435
2963
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
2436
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
2437
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2438
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2964
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
2965
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2966
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
2439
2967
  };
2440
2968
 
2441
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2442
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
2969
+ template <int mmq_x, int mmq_y, bool need_check>
2970
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
2443
2971
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
2444
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
2445
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2446
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2972
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
2973
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2974
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
2447
2975
  };
2448
2976
 
2449
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2450
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
2977
+ template <int mmq_x, int mmq_y, bool need_check>
2978
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
2451
2979
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
2452
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
2453
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2454
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2980
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
2981
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
2982
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
2455
2983
  };
2456
2984
 
2457
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2458
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
2985
+ template <int mmq_x, int mmq_y, bool need_check>
2986
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
2459
2987
  static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
2460
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
2461
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2462
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2988
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
2989
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2990
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2463
2991
  };
2464
2992
 
2465
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2466
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
2993
+ template <int mmq_x, int mmq_y, bool need_check>
2994
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
2467
2995
  static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
2468
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
2469
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2470
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2996
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
2997
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
2998
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
2471
2999
  };
2472
3000
 
2473
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2474
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
3001
+ template <int mmq_x, int mmq_y, bool need_check>
3002
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
2475
3003
  static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
2476
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
2477
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2478
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3004
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
3005
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3006
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
2479
3007
  };
2480
3008
 
2481
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2482
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
3009
+ template <int mmq_x, int mmq_y, bool need_check>
3010
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
2483
3011
  static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
2484
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
2485
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2486
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3012
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
3013
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3014
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2487
3015
  };
2488
3016
 
2489
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2490
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
3017
+ template <int mmq_x, int mmq_y, bool need_check>
3018
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
2491
3019
  static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
2492
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
2493
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2494
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3020
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
3021
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3022
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2495
3023
  };
2496
3024
 
2497
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2498
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
3025
+ template <int mmq_x, int mmq_y, bool need_check>
3026
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
2499
3027
  static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
2500
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
2501
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2502
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3028
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
3029
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3030
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2503
3031
  };
2504
3032
 
2505
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2506
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
3033
+ template <int mmq_x, int mmq_y, bool need_check>
3034
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
2507
3035
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2508
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2509
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2510
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3036
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
3037
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3038
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2511
3039
  };
2512
3040
 
2513
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2514
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
3041
+ template <int mmq_x, int mmq_y, bool need_check>
3042
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
2515
3043
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2516
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2517
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2518
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
3044
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
3045
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3046
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2519
3047
  };
2520
3048
 
2521
- template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
3049
+ template <ggml_type type, int mmq_x, bool need_check, bool fixup>
2522
3050
  static __device__ __forceinline__ void mul_mat_q_process_tile(
2523
3051
  const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
2524
3052
  const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525
3053
  const int stride_row_x, const int ncols_y, const int stride_col_dst,
2526
3054
  const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
2527
3055
 
3056
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3057
+ constexpr int nwarps = mmq_get_nwarps_device();
2528
3058
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2529
3059
  constexpr int mmq_y = get_mmq_y_device();
2530
- constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
3060
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
2531
3061
 
2532
3062
  extern __shared__ int data_mul_mat_q[];
2533
3063
  int * tile_y = data_mul_mat_q + mmq_x;
2534
- int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
3064
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
2535
3065
 
2536
- #ifdef NEW_MMA_AVAILABLE
2537
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
2538
- constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
3066
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3067
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3068
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
2539
3069
  #else
2540
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
2541
- constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2542
- #endif // NEW_MMA_AVAILABLE
3070
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3071
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3072
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2543
3073
 
2544
3074
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2545
3075
 
2546
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
3076
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
2547
3077
 
2548
3078
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
2549
3079
  load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
@@ -2551,8 +3081,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2551
3081
  {
2552
3082
  const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
2553
3083
  #pragma unroll
2554
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2555
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3084
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3085
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2556
3086
 
2557
3087
  tile_y[l] = by0[l];
2558
3088
  }
@@ -2567,8 +3097,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2567
3097
  {
2568
3098
  const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
2569
3099
  #pragma unroll
2570
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2571
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3100
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3101
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2572
3102
 
2573
3103
  tile_y[l] = by0[l];
2574
3104
  }
@@ -2576,7 +3106,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2576
3106
 
2577
3107
  __syncthreads();
2578
3108
 
2579
- vec_dot(tile_x, tile_y, sum, WARP_SIZE);
3109
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
2580
3110
 
2581
3111
  __syncthreads();
2582
3112
  }
@@ -2591,24 +3121,25 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
2591
3121
 
2592
3122
  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
2593
3123
 
2594
- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2595
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3124
+ template <ggml_type type, int mmq_x, bool need_check>
3125
+ #if defined(GGML_USE_HIP)
2596
3126
  #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2597
- __launch_bounds__(WARP_SIZE*nwarps, 2)
3127
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
2598
3128
  #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2599
3129
  #else
2600
3130
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2601
- __launch_bounds__(WARP_SIZE*nwarps, 1)
3131
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
2602
3132
  #else
2603
- __launch_bounds__(WARP_SIZE*nwarps, 2)
3133
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
2604
3134
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2605
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3135
+ #endif // defined(GGML_USE_HIP)
2606
3136
  static __global__ void mul_mat_q(
2607
3137
  const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
2608
3138
  const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609
3139
  const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
2610
3140
  const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
2611
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
3141
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3142
+ const int ncols_max) {
2612
3143
 
2613
3144
  // Skip unused template specializations for faster compilation:
2614
3145
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2616,10 +3147,13 @@ static __global__ void mul_mat_q(
2616
3147
  return;
2617
3148
  }
2618
3149
 
3150
+ constexpr int nwarps = mmq_get_nwarps_device();
3151
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3152
+
2619
3153
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2620
3154
  constexpr int mmq_y = get_mmq_y_device();
2621
3155
 
2622
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
3156
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
2623
3157
  const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
2624
3158
 
2625
3159
  // Initialize the ids for writing back data with just the index.
@@ -2627,10 +3161,10 @@ static __global__ void mul_mat_q(
2627
3161
  // For MoE the correct indices are loaded from ids_dst.
2628
3162
  extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
2629
3163
  #pragma unroll
2630
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2631
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3164
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3165
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2632
3166
 
2633
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3167
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2634
3168
  break;
2635
3169
  }
2636
3170
 
@@ -2638,8 +3172,8 @@ static __global__ void mul_mat_q(
2638
3172
  }
2639
3173
  __syncthreads();
2640
3174
 
2641
- // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2642
- #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3175
+ // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
3176
+ #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2643
3177
  {
2644
3178
  const int wt = blockIdx.z / nchannels_y;
2645
3179
  const int zt = blockIdx.z - wt*nchannels_y;
@@ -2667,10 +3201,10 @@ static __global__ void mul_mat_q(
2667
3201
 
2668
3202
  // __syncthreads(); // There is no previous tile that could cause a race condition.
2669
3203
  #pragma unroll
2670
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2671
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3204
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3205
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2672
3206
 
2673
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3207
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2674
3208
  break;
2675
3209
  }
2676
3210
 
@@ -2688,12 +3222,12 @@ static __global__ void mul_mat_q(
2688
3222
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2689
3223
 
2690
3224
  constexpr bool fixup = false;
2691
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3225
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2692
3226
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2693
3227
  tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
2694
3228
  return;
2695
3229
  }
2696
- #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3230
+ #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2697
3231
 
2698
3232
  const int64_t blocks_per_ne00 = ncols_x / qk;
2699
3233
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -2745,10 +3279,10 @@ static __global__ void mul_mat_q(
2745
3279
 
2746
3280
  __syncthreads();
2747
3281
  #pragma unroll
2748
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2749
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3282
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3283
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2750
3284
 
2751
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3285
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2752
3286
  break;
2753
3287
  }
2754
3288
 
@@ -2766,7 +3300,7 @@ static __global__ void mul_mat_q(
2766
3300
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2767
3301
 
2768
3302
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
2769
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3303
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2770
3304
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2771
3305
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2772
3306
 
@@ -2812,10 +3346,10 @@ static __global__ void mul_mat_q(
2812
3346
  // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
2813
3347
  __syncthreads();
2814
3348
  #pragma unroll
2815
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2816
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
3349
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3350
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
2817
3351
 
2818
- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
3352
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
2819
3353
  break;
2820
3354
  }
2821
3355
 
@@ -2833,25 +3367,29 @@ static __global__ void mul_mat_q(
2833
3367
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2834
3368
 
2835
3369
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2836
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
3370
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
2837
3371
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2838
3372
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2839
3373
  }
2840
3374
 
2841
3375
 
2842
- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
3376
+ template <ggml_type type, int mmq_x, bool need_check>
2843
3377
  static __global__ void mul_mat_q_stream_k_fixup(
2844
3378
  const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2845
3379
  const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
2846
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
3380
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
3381
+ const int ncols_max) {
2847
3382
  constexpr int mmq_y = get_mmq_y_device();
2848
3383
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2849
3384
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2850
3385
  const int64_t blocks_per_ne00 = ncols_x / qk;
2851
3386
 
2852
- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
3387
+ constexpr int nwarps = mmq_get_nwarps_device();
3388
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2853
3389
 
2854
- const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
3390
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3391
+
3392
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
2855
3393
  const int nty = (nrows_x + mmq_y - 1) / mmq_y;
2856
3394
 
2857
3395
  const int bidx0 = blockIdx.x;
@@ -2893,10 +3431,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
2893
3431
  const int j = j0 + threadIdx.y;
2894
3432
 
2895
3433
  #pragma unroll
2896
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3434
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2897
3435
  const int i = i0 + threadIdx.x;
2898
3436
 
2899
- sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3437
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
2900
3438
  }
2901
3439
  }
2902
3440
 
@@ -2937,14 +3475,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
2937
3475
  }
2938
3476
 
2939
3477
  #pragma unroll
2940
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3478
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2941
3479
  const int i = i0 + threadIdx.x;
2942
3480
 
2943
3481
  if (need_check && i > i_max) {
2944
3482
  continue;
2945
3483
  }
2946
3484
 
2947
- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
3485
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2948
3486
  }
2949
3487
  }
2950
3488
  return;
@@ -2955,7 +3493,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
2955
3493
  const int col_high = expert_bounds[zt + 1];
2956
3494
  const int col_diff = col_high - col_low;
2957
3495
 
2958
- for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
3496
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
2959
3497
  ids_dst_shared[j] = ids_dst[col_low + j];
2960
3498
  }
2961
3499
  __syncthreads();
@@ -2975,14 +3513,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
2975
3513
  }
2976
3514
 
2977
3515
  #pragma unroll
2978
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
3516
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2979
3517
  const int i = i0 + threadIdx.x;
2980
3518
 
2981
3519
  if (need_check && i > i_max) {
2982
3520
  continue;
2983
3521
  }
2984
3522
 
2985
- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
3523
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2986
3524
  }
2987
3525
  }
2988
3526
  }
@@ -2992,17 +3530,17 @@ struct mmq_args {
2992
3530
  int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
2993
3531
  int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
2994
3532
  int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
2995
- bool use_stream_k;
3533
+ bool use_stream_k; int64_t ncols_max;
2996
3534
  };
2997
3535
 
2998
3536
  template<ggml_type type>
2999
- static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
3537
+ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
3000
3538
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3001
3539
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3002
3540
  const size_t nbs_ids = mmq_x*sizeof(int);
3003
- const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3541
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3004
3542
  const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3005
- return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
3543
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3006
3544
  }
3007
3545
 
3008
3546
  template <ggml_type type, int mmq_x>
@@ -3010,23 +3548,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);
 
-    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+    const dim3 block_dims(warp_size, nwarps, 1);
 
-    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shared_memory_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        shared_memory_limit_raised[id] = true;
-    }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
 
     const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
-    const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
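Note: the removed block shows what the new CUDA_SET_SHARED_MEMORY_LIMIT macro has to accomplish on CUDA builds: a kernel that requests more dynamic shared memory than the 48 KiB default must opt in once via cudaFuncSetAttribute. A hedged stand-alone sketch of that opt-in (the kernel, block size, and 96 KiB figure are stand-ins, not taken from this diff):

// Sketch of raising the dynamic shared-memory limit for one kernel.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(int * out) {
    extern __shared__ int smem[];       // dynamic shared memory
    smem[threadIdx.x] = threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
    const int nbytes_shared = 96 * 1024;   // more than the 48 KiB default

    // Opt the kernel in to the larger dynamic shared-memory carve-out.
    cudaError_t err = cudaFuncSetAttribute(
        dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared);
    if (err != cudaSuccess) {
        printf("cudaFuncSetAttribute failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    int * d_out = nullptr;
    cudaMalloc(&d_out, 256 * sizeof(int));
    dummy_kernel<<<1, 256, nbytes_shared>>>(d_out);
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(d_out);
    return 0;
}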
@@ -3038,18 +3572,20 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     if (!args.use_stream_k) {
         if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         } else {
             constexpr bool need_check = true;
-            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         }
         return;
     }
@@ -3065,44 +3601,48 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
-
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+            args.ncols_max);
 
         if (!fixup_needed) {
            return;
        }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+            args.ncols_max);
     } else {
         constexpr bool need_check = true;
-
-        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+            args.ncols_max);
 
         if (!fixup_needed) {
            return;
        }
 
-        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+            args.ncols_max);
     }
 }
 
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
-    const int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id        = ggml_cuda_get_device();
+    const int    cc        = ggml_cuda_info().devices[id].cc;
+    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
+    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int    nwarps    = mmq_get_nwarps_host(cc, warp_size);
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
@@ -3113,11 +3653,11 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
         const int granularity = mmq_get_granularity_host(mmq_x, cc);
 
-        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
             continue;
         }
 
-        const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
 
         if (ntiles_x < ntiles_x_best) {
             mmq_x_best = mmq_x;
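Note: with ncols_max added to mmq_args, the tile width is now chosen against the widest destination tile this launch can see rather than ncols_y. A host-side sketch of the selection loop above, using a placeholder shared-memory cost model and illustrative limits:

// Sketch of the mmq_x selection heuristic; fake_nbytes_shared() is invented.
#include <climits>
#include <cstdio>

static size_t fake_nbytes_shared(int mmq_x) {   // placeholder cost model
    return 2048 + 768 * static_cast<size_t>(mmq_x);
}

int main() {
    const int     mmq_x_max   = 128;
    const int     granularity = 16;
    const size_t  smpbo       = 64 * 1024;      // shared memory per block (opt-in)
    const long long ncols_max = 100;            // widest dst tile for this launch

    int mmq_x_best    = 0;
    int ntiles_x_best = INT_MAX;
    for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
        // Skip widths that break the layout granularity or overflow shared memory.
        if (mmq_x % granularity != 0 || fake_nbytes_shared(mmq_x) > smpbo) {
            continue;
        }
        const int ntiles_x = (ncols_max + mmq_x - 1) / mmq_x;
        if (ntiles_x < ntiles_x_best) {         // prefer the fewest column tiles
            mmq_x_best    = mmq_x;
            ntiles_x_best = ntiles_x;
        }
    }
    printf("chose mmq_x=%d (%d tile(s) along ncols_max)\n", mmq_x_best, ntiles_x_best);
    return 0;
}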
@@ -3189,6 +3729,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);