@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh

@@ -90,7 +90,7 @@ struct tile_x_sizes {
 };

 static int get_mmq_x_max_host(const int cc) {
-    return new_mma_available(cc) ? 128 :
+    return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -100,12 +100,12 @@ static int get_mmq_x_max_host(const int cc) {
 }

 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     return 128;
-#else // NEW_MMA_AVAILABLE
+#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-    return 128;
+    return 64;
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
@@ -115,12 +115,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 }

 static int get_mmq_y_host(const int cc) {
@@ -144,16 +143,25 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }

-#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
-#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
-#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
-#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
-#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either,
+// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
+// 32 bit elements for the quantized data (does not include scales).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
@@ -179,11 +187,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }

-#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)

 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
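The new MMQ_TILE_NE_K comment above fixes the tile stride rules: dp4a tiles pad each row of quantized data to an odd number of 32-bit elements, mma tiles to a stride with K % 8 == 4. A standalone C++ sketch of that arithmetic (not part of the package diff; QI8_0 == 8 and QI6_K == 32 are assumed values of the ggml quantization constants):

    // Standalone sketch: checks the padding rules stated in the diff for MMQ_TILE_NE_K == 32.
    #include <cstdio>

    constexpr int MMQ_TILE_NE_K = 32;
    constexpr int QI8_0 = 8;   // assumed ggml constant
    constexpr int QI6_K = 32;  // assumed ggml constant

    // mirrors MMQ_MMA_TILE_X_K_Q8_0 / MMQ_MMA_TILE_X_K_Q6_K above
    constexpr int TILE_X_K_Q8_0 = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4;                 // 64 + 8 + 4 = 76
    constexpr int TILE_X_K_Q6_K = 2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7; // 64 + 1 + 4 + 7 = 76

    // mma rule: K stride % 8 == 4
    static_assert(TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
    static_assert(TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");

    // dp4a rule: the qs row stride is MMQ_TILE_NE_K + 1 (or 2*MMQ_TILE_NE_K + 1), i.e. odd
    static_assert((MMQ_TILE_NE_K + 1) % 2 == 1, "Wrong padding.");
    static_assert((2*MMQ_TILE_NE_K + 1) % 2 == 1, "Wrong padding.");

    int main() {
        std::printf("q8_0 mma tile: %d ints per row in the K direction\n", TILE_X_K_Q8_0);
        return 0;
    }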
@@ -215,42 +223,80 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }

-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)

 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    if (amd_mfma_available(cc)) {
+        return mmq_x >= 128 ? 32 : 16;
+    } else if (new_mma_available(cc) && mmq_x >= 48) {
+        return 16;
+    } else {
+        return 8;
+    }
 }

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(NEW_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
 #else
-static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
     return 8;
 }
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int mmq_get_nwarps_host(const int cc) {
+    return amd_mfma_available(cc) ? 8 : 4;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/) {
+    return 8;
+}
+#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(AMD_MFMA_AVAILABLE)
+    return 8;
+#else
+    return 4;
+#endif // AMD_MFMA_AVAILABLE
+#else
+    return 8;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+}

 // ------------------------------------------------------------

-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-    const int kbx = threadIdx.x / QI4_0;
-    const int kqsx = threadIdx.x % QI4_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_0;
+    const int kqsx = txi % QI4_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
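The rewritten loaders derive nwarps and the physical warp size inside the kernel (mmq_get_nwarps_device() and ggml_cuda_get_physical_warp_size() above) and let a warp cover more than one tile row when it is wider than threads_per_row. A host-side C++ sketch of the indexing (not part of the package diff; MMQ_ITER_K == 256, QR4_0 == 2 and QI4_0 == 4 are assumed, since those constants are not shown in this excerpt):

    // Host-side sketch of the new q4_0 tile-load indexing for 32-wide and 64-wide warps.
    #include <cstdio>

    int main() {
        constexpr int MMQ_ITER_K = 256;     // assumption, not shown in this excerpt
        constexpr int QR4_0 = 2, QI4_0 = 4; // assumed ggml constants

        // 32 lanes * 4 bytes * 2 packed values per byte cover one 256-value K slice of a row
        constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);

        const int warp_sizes[2] = {32, 64}; // NVIDIA warp vs. AMD wave
        for (const int warp_size : warp_sizes) {
            const int nrows = warp_size / threads_per_row; // rows loaded per warp per iteration
            const int lane  = 37 % warp_size;              // an arbitrary example lane
            const int txi   = warp_size > threads_per_row ? lane % threads_per_row : lane;
            const int row   = warp_size > threads_per_row ? lane / threads_per_row : 0;
            std::printf("warp_size=%d: nrows=%d, lane=%d -> txi=%d (kbx=%d, kqsx=%d), row offset=%d\n",
                        warp_size, nrows, lane, txi, txi / QI4_0, txi % QI4_0, row);
        }
        return 0;
    }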
@@ -259,20 +305,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
-        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -280,17 +327,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;

-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;

 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
@@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                 }

-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
-                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }

-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-    const int kbx = threadIdx.x / QI4_1;
-    const int kqsx = threadIdx.x % QI4_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_1;
+    const int kqsx = txi % QI4_1;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
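In the dp4a dot products above, each thread accumulates into sum[j0/nwarps*mmq_y/warp_size + i0/warp_size], so it holds (mmq_x/nwarps) * (mmq_y/warp_size) partial results. A small C++ sketch with illustrative tile sizes (the mmq_x and mmq_y values are assumptions, not fixed by the diff):

    // Counts the per-thread accumulators implied by the dp4a loops above.
    #include <cstdio>

    int main() {
        const int mmq_x = 64, mmq_y = 128;     // illustrative tile shape (assumption)
        const int nwarps = 8, warp_size = 32;  // e.g. the non-HIP defaults shown above

        int count = 0, max_idx = -1;
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                const int idx = j0/nwarps*mmq_y/warp_size + i0/warp_size; // index used above
                if (idx > max_idx) max_idx = idx;
                ++count;
            }
        }
        // (mmq_x/nwarps) * (mmq_y/warp_size) = 8 * 4 = 32 partial sums per thread
        std::printf("per-thread accumulators: %d (max index %d)\n", count, max_idx);
        return 0;
    }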
@@ -354,20 +408,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
-        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -375,17 +430,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;

-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;
@@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;

 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
@@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                 }

-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
-                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }

-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-    const int kbx = threadIdx.x / QI5_0;
-    const int kqsx = threadIdx.x % QI5_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_0;
+    const int kqsx = txi % QI5_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
@@ -449,7 +511,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

         const int ql = get_int_b2(bxi->qs, kqsx);
-        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -465,21 +527,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
         qs1 = __vsubss4(qs1, 0x10101010); // subtract 16

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
-        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -487,32 +550,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-    const int kbx = threadIdx.x / QI5_1;
-    const int kqsx = threadIdx.x % QI5_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_1;
+    const int kqsx = txi % QI5_1;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
@@ -521,7 +589,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

         const int ql = get_int_b4(bxi->qs, kqsx);
-        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -535,21 +603,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
-        x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
-        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -557,32 +626,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-    const int kbx = threadIdx.x / QI8_0;
-    const int kqsx = threadIdx.x % QI8_0;
+    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
+    constexpr int threads_per_row = 32;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI8_0;
+    const int kqsx = txi % QI8_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
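For q8_0 the comment above notes that MMQ_ITER_K / (4 * QR8_0) would need 64 lanes per row, more than a 32-wide NVIDIA warp has, so threads_per_row is clamped to 32 and each lane instead loads two 32-bit values per row (the two get_int_b2() reads in the next hunk). A quick arithmetic sketch (MMQ_ITER_K == 256, QR8_0 == 1 and QI8_0 == 8 are assumed values not shown in this excerpt):

    // Sketch of the q8_0 thread-per-row clamp described in the diff comment.
    #include <cstdio>

    int main() {
        constexpr int MMQ_ITER_K = 256, QR8_0 = 1; // assumed ggml constants
        constexpr int natural_threads_per_row = MMQ_ITER_K / (4 * QR8_0); // 64
        constexpr int threads_per_row = 32;                               // clamped to the NVIDIA warp width
        constexpr int loads_per_lane = natural_threads_per_row / threads_per_row; // 2 x 32-bit loads per lane per row
        std::printf("natural=%d, clamped=%d, 32-bit loads per lane per row=%d\n",
                    natural_threads_per_row, threads_per_row, loads_per_lane);
        return 0;
    }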
@@ -590,21 +665,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

-#ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
-        x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
-        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -612,17 +688,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;

-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -631,7 +709,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const float * y_df = (const float *) y;

 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
@@ -639,21 +717,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
-                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
-                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
             }
         }
     }
 }

-template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
731
+ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
654
732
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
655
733
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
734
+ #if defined(AMD_MFMA_AVAILABLE)
735
+ typedef tile<16, 8, int> tile_A;
736
+ typedef tile<16, 8, int> tile_B;
737
+ typedef tile<16, 16, int> tile_C;
738
+
739
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
740
+ constexpr int rows_per_warp = granularity;
741
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
656
742
 
743
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
744
+
745
+ const int * x_qs = (const int *) x;
746
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
747
+ const int * y_qs = (const int *) y + 4;
748
+ const float * y_df = (const float *) y;
749
+ const half2 * y_ds = (const half2 *) y;
750
+
751
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
752
+
753
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
754
+ const int k0 = k00 + k01;
755
+
756
+ tile_A A[ntx];
757
+ #pragma unroll
758
+ for (int n = 0; n < ntx; ++n) {
759
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
760
+ }
761
+
762
+ #pragma unroll
763
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
764
+ tile_B B;
765
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
766
+
767
+ float dB;
768
+ const int j = j0 + tile_C::get_j(0);
769
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
770
+ dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
771
+ } else {
772
+ dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
773
+ }
774
+
775
+ #pragma unroll
776
+ for (int n = 0; n < ntx; ++n) {
777
+ tile_C C;
778
+ mma(C, A[n], B);
779
+
780
+ #pragma unroll
781
+ for (int l = 0; l < tile_C::ne; ++l) {
782
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
783
+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
784
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
785
+ }
786
+ }
787
+ }
788
+ }
789
+ #else
657
790
  typedef tile<16, 8, int> tile_A;
658
791
  typedef tile< 8, 8, int> tile_B;
659
792
  typedef tile<16, 8, int> tile_C;
@@ -662,23 +795,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
662
795
  constexpr int rows_per_warp = 2 * granularity;
663
796
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
664
797
 
665
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
798
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
666
799
 
667
800
  const int * x_qs = (const int *) x;
668
- const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
801
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
669
802
  const int * y_qs = (const int *) y + 4;
670
803
  const float * y_df = (const float *) y;
671
804
  const half2 * y_ds = (const half2 *) y;
672
805
 
673
- tile_A A[ntx][WARP_SIZE/QI8_0];
674
- float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
806
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
807
+ float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
675
808
 
676
809
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
677
810
 
678
811
  #pragma unroll
679
812
  for (int n = 0; n < ntx; ++n) {
680
813
  #pragma unroll
681
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
814
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
682
815
  const int k0 = k00 + k01;
683
816
 
684
817
  load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
@@ -689,7 +822,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
689
822
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
690
823
 
691
824
  #pragma unroll
692
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
825
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
693
826
  const int k0 = k00 + k01;
694
827
 
695
828
  dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
@@ -700,7 +833,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
700
833
  #pragma unroll
701
834
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
702
835
  #pragma unroll
703
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
836
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
704
837
  tile_B B;
705
838
  float dB[tile_C::ne/2];
706
839
 
@@ -729,11 +862,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
729
862
  }
730
863
  }
731
864
  }
865
+ #endif // defined(AMD_MFMA_AVAILABLE)
732
866
  }
733
867
 
734
- template <int mmq_x, int mmq_y, int nwarps>
868
+ template <int mmq_x, int mmq_y>
735
869
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
736
870
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
871
+ constexpr int nwarps = mmq_get_nwarps_device();
872
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
737
873
 
738
874
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
739
875
  const int * x_qs = (const int *) x;
@@ -742,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
742
878
  const half2 * y_ds = (const half2 *) y;
743
879
 
744
880
  // #pragma unroll
745
- for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
881
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
746
882
  const int k0 = k00 + k01;
747
883
 
748
884
  #pragma unroll
@@ -750,45 +886,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
750
886
  const int j = j0 + threadIdx.y;
751
887
 
752
888
  #pragma unroll
753
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
889
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
754
890
  const int i = i0 + threadIdx.x;
755
891
 
756
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
757
- (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
758
- x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
892
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
893
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
894
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
759
895
  }
760
896
  }
761
897
  }
762
898
  }
763
899
 
764
- template <int mmq_x, int mmq_y, int nwarps>
900
+ template <int mmq_x, int mmq_y>
765
901
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
766
902
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
903
+ #if defined(AMD_MFMA_AVAILABLE)
904
+ typedef tile<16, 8, int> tile_A;
905
+ typedef tile<16, 8, int> tile_B;
906
+ typedef tile<16, 16, int> tile_C;
767
907
 
768
- typedef tile<16, 8, int> tile_A;
769
- typedef tile< 8, 8, int> tile_B;
770
- typedef tile<16, 8, int> tile_C;
908
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
909
+ constexpr int rows_per_warp = granularity;
910
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
911
+
912
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
913
+
914
+ const int * x_qs = (const int *) x;
915
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
916
+ const int * y_qs = (const int *) y + 4;
917
+ const half2 * y_dm = (const half2 *) y;
918
+
919
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
920
+
921
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
922
+ const int k0 = k00 + k01;
923
+
924
+ tile_A A[ntx];
925
+ #pragma unroll
926
+ for (int n = 0; n < ntx; ++n) {
927
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
928
+ }
929
+
930
+ #pragma unroll
931
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
932
+ tile_B B;
933
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
934
+
935
+ const int j = j0 + tile_C::get_j(0);
936
+ const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
937
+
938
+ #pragma unroll
939
+ for (int n = 0; n < ntx; ++n) {
940
+ tile_C C;
941
+ mma(C, A[n], B);
942
+
943
+ #pragma unroll
944
+ for (int l = 0; l < tile_C::ne; ++l) {
945
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
946
+ float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
947
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
948
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
949
+ }
950
+ }
951
+ }
952
+ }
953
+ #else
954
+ typedef tile<16, 8, int> tile_A;
955
+ typedef tile< 8, 8, int> tile_B;
956
+ typedef tile<16, 8, int> tile_C;
771
957
 
772
958
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
773
959
  constexpr int rows_per_warp = 2 * granularity;
774
960
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
775
961
 
776
- y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
962
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
777
963
 
778
964
  const int * x_qs = (const int *) x;
779
- const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
965
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
780
966
  const int * y_qs = (const int *) y + 4;
781
967
  const half2 * y_dm = (const half2 *) y;
782
968
 
783
- tile_A A[ntx][WARP_SIZE/QI8_1];
784
- float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
969
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
970
+ float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
785
971
 
786
972
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
787
973
 
788
974
  #pragma unroll
789
975
  for (int n = 0; n < ntx; ++n) {
790
976
  #pragma unroll
791
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
977
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
792
978
  const int k0 = k00 + k01;
793
979
 
794
980
  load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
@@ -799,7 +985,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
799
985
  const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
800
986
 
801
987
  #pragma unroll
802
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
988
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
803
989
  const int k0 = k00 + k01;
804
990
 
805
991
  dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
@@ -810,7 +996,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
810
996
  #pragma unroll
811
997
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
812
998
  #pragma unroll
813
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
999
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
814
1000
  tile_B B;
815
1001
  float2 dsB[tile_C::ne/2];
816
1002
 
@@ -836,11 +1022,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
836
1022
  }
837
1023
  }
838
1024
  }
1025
+ #endif // defined(AMD_MFMA_AVAILABLE)
839
1026
  }
840
1027
 
841
- template <int mmq_x, int mmq_y, int nwarps>
1028
+ // Used for Q3_K, IQ2_S, and IQ2_XS
1029
+ template <int mmq_x, int mmq_y>
842
1030
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
843
1031
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1032
+ constexpr int nwarps = mmq_get_nwarps_device();
1033
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
844
1034
 
845
1035
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
846
1036
  const int * x_qs = (const int *) x;
@@ -849,7 +1039,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
849
1039
  const float * y_df = (const float *) y;
850
1040
 
851
1041
  // #pragma unroll
852
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
1042
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
853
1043
  const int k0 = k00 + k01;
854
1044
 
855
1045
  #pragma unroll
@@ -857,23 +1047,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
857
1047
  const int j = j0 + threadIdx.y;
858
1048
 
859
1049
  #pragma unroll
860
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1050
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
861
1051
  const int i = i0 + threadIdx.x;
862
1052
 
863
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
864
- &x_qs[i*(2*WARP_SIZE + 1) + k0],
1053
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
1054
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
865
1055
  &y_qs[j*MMQ_TILE_Y_K + k01],
866
- &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
1056
+ &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
867
1057
  y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
868
1058
  }
869
1059
  }
870
1060
  }
871
1061
  }
872
1062
 
873
- template <int mmq_x, int mmq_y, int nwarps>
1063
+ // Used for Q3_K, IQ2_S, and IQ2_XS:
1064
+ template <int mmq_x, int mmq_y>
874
1065
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
875
1066
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
876
- #ifdef NEW_MMA_AVAILABLE
1067
+ #if defined(AMD_MFMA_AVAILABLE)
1068
+ typedef tile<16, 8, int> tile_A;
1069
+ typedef tile<16, 8, int> tile_B;
1070
+ typedef tile<16, 16, int> tile_C;
1071
+ typedef tile<64, 2, int> tile_load;
1072
+
1073
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1074
+ constexpr int rows_per_warp = granularity;
1075
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1076
+
1077
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1078
+
1079
+ const int * x_qs = (const int *) x;
1080
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1081
+ const int * y_qs = (const int *) y + 4;
1082
+ const float * y_df = (const float *) y;
1083
+
1084
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1085
+
1086
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1087
+ const int k0 = k00 + k01;
1088
+
1089
+ tile_A A[ntx];
1090
+ #pragma unroll
1091
+ for (int n = 0; n < ntx; ++n) {
1092
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1093
+ }
1094
+
1095
+ #pragma unroll
1096
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1097
+ tile_B B[1];
1098
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1099
+
1100
+ const int j = j0 + tile_C::get_j(0);
1101
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1102
+
1103
+ #pragma unroll
1104
+ for (int n = 0; n < ntx; ++n) {
1105
+ tile_C C;
1106
+ mma(C, A[n], B[0]);
1107
+
1108
+ #pragma unroll
1109
+ for (int l = 0; l < tile_C::ne; ++l) {
1110
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1111
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1112
+ }
1113
+ }
1114
+ }
1115
+ }
1116
+ #elif defined(NEW_MMA_AVAILABLE)
877
1117
 
878
1118
  typedef tile<16, 4, int> tile_A;
879
1119
  typedef tile<16, 8, int> tile_A_8;
@@ -884,10 +1124,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
884
1124
  constexpr int rows_per_warp = 2 * granularity;
885
1125
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
886
1126
 
887
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1127
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
888
1128
 
889
1129
  const int * x_qs = (const int *) x;
890
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1130
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
891
1131
  const int * y_qs = (const int *) y + 4;
892
1132
  const float * y_df = (const float *) y;
893
1133
 
@@ -899,7 +1139,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
899
1139
  #pragma unroll
900
1140
  for (int n = 0; n < ntx; ++n) {
901
1141
  #pragma unroll
902
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1142
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
903
1143
  const int k0 = k00 + k01;
904
1144
 
905
1145
  load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
@@ -910,7 +1150,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
910
1150
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
911
1151
 
912
1152
  #pragma unroll
913
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
1153
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
914
1154
  const int k0 = k00 + k01;
915
1155
 
916
1156
  dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
@@ -921,7 +1161,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
921
1161
  #pragma unroll
922
1162
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
923
1163
  #pragma unroll
924
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1164
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
925
1165
  tile_B B[2];
926
1166
  float dB[tile_C::ne/2];
927
1167
 
@@ -952,26 +1192,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
952
1192
  #else
953
1193
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
954
1194
  NO_DEVICE_CODE;
955
- #endif // NEW_MMA_AVAILABLE
1195
+ #endif // AMD_MFMA_AVAILABLE
956
1196
  }
957
1197
 
958
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1198
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
959
1199
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1200
+ constexpr int nwarps = mmq_get_nwarps_device();
960
1201
 
961
- #ifdef NEW_MMA_AVAILABLE
1202
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
962
1203
  int * x_qs = (int *) x_tile;
963
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1204
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
964
1205
  #else
965
1206
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
966
1207
  int * x_qs = (int *) x_tile;
967
1208
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
968
- #endif // NEW_MMA_AVAILABLE
1209
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
969
1210
 
970
- const int kqsx = threadIdx.x % QI2_K;
1211
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1212
+ constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
1213
+ const int kqsx = threadIdx.x % threads_per_row;
971
1214
 
972
1215
  #pragma unroll
973
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
974
- int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
1216
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1217
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
975
1218
 
976
1219
  if (need_check) {
977
1220
  i = min(i, i_max);
@@ -987,11 +1230,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
987
1230
 
988
1231
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
989
1232
 
990
- #ifdef NEW_MMA_AVAILABLE
1233
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
991
1234
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
992
1235
  #else
993
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
994
- #endif // NEW_MMA_AVAILABLE
1236
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1237
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
995
1238
  }
996
1239
 
997
1240
  const int sc_m = bxi->scales[kqsx];
@@ -1002,17 +1245,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1002
1245
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1003
1246
  #endif // FAST_FP16_AVAILABLE
1004
1247
 
1005
- #ifdef NEW_MMA_AVAILABLE
1248
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1006
1249
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1007
1250
  #else
1008
- x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
1009
- #endif // NEW_MMA_AVAILABLE
1251
+ x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1252
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1010
1253
  }
1011
1254
  }
1012
1255
 
1013
- template <int mmq_x, int mmq_y, int nwarps>
1256
+ template <int mmq_x, int mmq_y>
1014
1257
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1015
1258
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1259
+ constexpr int nwarps = mmq_get_nwarps_device();
1260
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1016
1261
 
1017
1262
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1018
1263
  const int * x_qs = (const int *) x;
@@ -1029,7 +1274,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1029
1274
  }
1030
1275
 
1031
1276
  #pragma unroll
1032
- for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1277
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1033
1278
  const int k0 = k00 + k01;
1034
1279
 
1035
1280
  #pragma unroll
@@ -1037,13 +1282,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1037
1282
  const int j = j0 + threadIdx.y;
1038
1283
 
1039
1284
  #pragma unroll
1040
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1285
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1041
1286
  const int i = i0 + threadIdx.x;
1042
1287
 
1043
1288
  constexpr int ns = 2;
1044
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1045
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1046
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1289
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1290
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1291
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1047
1292
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1048
1293
  }
1049
1294
  }
@@ -1052,7 +1297,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1052
1297
  // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
1053
1298
  // As a workaround 2 separate loops are used instead.
1054
1299
  #pragma unroll
1055
- for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1300
+ for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1056
1301
  const int k0 = k00 + k01;
1057
1302
 
1058
1303
  #pragma unroll
@@ -1060,23 +1305,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1060
1305
  const int j = j0 + threadIdx.y;
1061
1306
 
1062
1307
  #pragma unroll
1063
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1308
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1064
1309
  const int i = i0 + threadIdx.x;
1065
1310
 
1066
1311
  constexpr int ns = 1;
1067
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1068
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1069
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1312
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1313
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1314
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1070
1315
  &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1071
1316
  }
1072
1317
  }
1073
1318
  }
1074
1319
  }
1075
1320
 
1076
- template <int mmq_x, int mmq_y, int nwarps>
1321
+ template <int mmq_x, int mmq_y>
1077
1322
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1078
1323
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1079
- #ifdef NEW_MMA_AVAILABLE
1324
+ #if defined(AMD_MFMA_AVAILABLE)
1325
+ typedef tile<16, 8, int> tile_A;
1326
+ typedef tile<16, 8, int> tile_B;
1327
+ typedef tile<16, 16, int> tile_C;
1328
+ typedef tile<64, 2, int> tile_load;
1329
+
1330
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1331
+ constexpr int rows_per_warp = granularity;
1332
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1333
+
1334
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1335
+
1336
+ const int * x_qs = (const int *) x;
1337
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1338
+ const int * y_qs = (const int *) y + 4;
1339
+ const half2 * y_ds = (const half2 *) y;
1340
+
1341
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1342
+
1343
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1344
+ const int k0 = k00 + k01;
1345
+
1346
+ tile_A A[ntx];
1347
+ #pragma unroll
1348
+ for (int n = 0; n < ntx; ++n) {
1349
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1350
+ }
1351
+
1352
+ #pragma unroll
1353
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1354
+ tile_B B[1];
1355
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1356
+
1357
+ const int j = j0 + tile_C::get_j(0);
1358
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1359
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1360
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1361
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1362
+
1363
+ tile_C Cm;
1364
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1365
+ tile_A A1;
1366
+ A1.x[0] = 0x01010101;
1367
+ A1.x[1] = 0x01010101;
1368
+ mma(Cm, A1, B[0]);
1369
+ }
1370
+
1371
+ #pragma unroll
1372
+ for (int n = 0; n < ntx; ++n) {
1373
+ tile_C Cd;
1374
+ mma(Cd, A[n], B[0]);
1375
+
1376
+ #pragma unroll
1377
+ for (int l = 0; l < tile_C::ne; ++l) {
1378
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1379
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1380
+ float tmp = Cd.x[l]*dm.x;
1381
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1382
+ tmp -= Cm.x[l]*dm.y;
1383
+ }
1384
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1385
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1386
+ }
1387
+ }
1388
+ }
1389
+ }
1390
+ #elif defined(NEW_MMA_AVAILABLE)
1080
1391
 
1081
1392
  typedef tile<16, 4, int> tile_A;
1082
1393
  typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1398,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1087
1398
  constexpr int rows_per_warp = 2 * granularity;
1088
1399
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1089
1400
 
1090
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
1401
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1091
1402
 
1092
1403
  const int * x_qs = (const int *) x;
1093
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1404
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1094
1405
  const int * y_qs = (const int *) y + 4;
1095
1406
  const half2 * y_ds = (const half2 *) y;
1096
1407
 
@@ -1103,7 +1414,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1103
1414
  #pragma unroll
1104
1415
  for (int n = 0; n < ntx; ++n) {
1105
1416
  #pragma unroll
1106
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1417
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1107
1418
  const int k0 = k00 + k01;
1108
1419
 
1109
1420
  load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1428,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1117
1428
  const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1118
1429
 
1119
1430
  #pragma unroll
1120
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
1431
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
1121
1432
  const int k0 = k00 + k01;
1122
1433
 
1123
1434
  const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1451,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1140
1451
  }
1141
1452
 
1142
1453
  #pragma unroll
1143
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1454
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1144
1455
  tile_B B[2];
1145
1456
 
1146
1457
  // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1459,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1148
1459
  load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1149
1460
 
1150
1461
  tile_C Cm[2];
1151
- if (k01 >= WARP_SIZE * 3/4) {
1462
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1152
1463
  tile_A A1;
1153
1464
  A1.x[0] = 0x01010101;
1154
1465
  A1.x[1] = 0x01010101;
@@ -1166,16 +1477,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1166
1477
  #pragma unroll
1167
1478
  for (int l = 0; l < tile_C::ne; ++l) {
1168
1479
  float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1169
- if (k01 >= WARP_SIZE * 3/4) {
1480
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
1170
1481
  tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1171
1482
  }
1172
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
1483
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
1173
1484
  }
1174
1485
  }
1175
1486
  }
1176
1487
 
1177
1488
  #pragma unroll
1178
- for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
1489
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
1179
1490
  float2 sB[tile_C::ne/2];
1180
1491
 
1181
1492
  #pragma unroll
@@ -1198,27 +1509,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1198
1509
  #else
1199
1510
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1200
1511
  NO_DEVICE_CODE;
1201
- #endif // NEW_MMA_AVAILABLE
1512
+ #endif // AMD_MFMA_AVAILABLE
1202
1513
  }
1203
1514
 
1204
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1515
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1205
1516
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1517
+ constexpr int nwarps = mmq_get_nwarps_device();
1518
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1206
1519
 
1207
- #ifdef NEW_MMA_AVAILABLE
1520
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1208
1521
  int * x_qs = (int *) x_tile;
1209
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1522
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1210
1523
  #else
1211
1524
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1212
1525
  int * x_qs = (int *) x_tile;
1213
1526
  float * x_df = (float *) (x_qs + txs.qs);
1214
1527
  int * x_sc = (int *) (x_df + txs.dm);
1215
- #endif // NEW_MMA_AVAILABLE
1528
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1216
1529
 
1217
- const int kqsx = threadIdx.x % QI3_K;
1530
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1531
+ constexpr int nrows = warp_size / threads_per_row;
1532
+ const int kqsx = threadIdx.x % threads_per_row;
1218
1533
 
1219
1534
  #pragma unroll
1220
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
1221
- int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
1535
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1536
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1222
1537
 
1223
1538
  if (need_check) {
1224
1539
  i = min(i, i_max);
@@ -1238,17 +1553,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1238
1553
 
1239
1554
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1240
1555
 
1241
- #ifdef NEW_MMA_AVAILABLE
1556
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1242
1557
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1243
1558
  #else
1244
- x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1245
- #endif // NEW_MMA_AVAILABLE
1559
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1560
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1246
1561
  }
1247
1562
  }
1248
1563
 
1564
+ constexpr int rows_per_warp = warp_size / 4;
1249
1565
  #pragma unroll
1250
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1251
- int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
1566
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1567
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
1252
1568
 
1253
1569
  if (need_check) {
1254
1570
  i = min(i, i_max);
@@ -1256,7 +1572,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1256
1572
 
1257
1573
  const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1258
1574
 
1259
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1575
+ const int ksc = threadIdx.x % 4;
1260
1576
 
1261
1577
  const int ksc_low = ksc % (QI3_K/8);
1262
1578
  const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1584,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1268
1584
 
1269
1585
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1270
1586
 
1271
- #ifdef NEW_MMA_AVAILABLE
1587
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1272
1588
  const int8_t * sc8 = (const int8_t *) &sc;
1273
1589
  const float d = bxi->d;
1274
1590
 
1275
1591
  #pragma unroll
1276
1592
  for (int l = 0; l < int(sizeof(int)); ++l) {
1277
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
1593
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
1278
1594
  }
1279
1595
  #else
1280
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1281
- #endif // NEW_MMA_AVAILABLE
1596
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1597
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1282
1598
  }
1283
1599
 
1284
- #ifndef NEW_MMA_AVAILABLE
1600
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
1285
1601
  #pragma unroll
1286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1287
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1602
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1603
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1288
1604
 
1289
1605
  if (need_check) {
1290
1606
  i = min(i, i_max);
@@ -1294,12 +1610,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1294
1610
 
1295
1611
  x_df[i] = bxi->d;
1296
1612
  }
1297
- #endif // NEW_MMA_AVAILABLE
1613
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
1298
1614
  }
1299
1615
 
1300
- template <int mmq_x, int mmq_y, int nwarps>
1616
+ template <int mmq_x, int mmq_y>
1301
1617
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1302
1618
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1619
+ constexpr int nwarps = mmq_get_nwarps_device();
1620
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1303
1621
 
1304
1622
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1305
1623
  const int * x_qs = (const int *) x;
@@ -1309,7 +1627,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1309
1627
  const float * y_df = (const float *) y;
1310
1628
 
1311
1629
  // #pragma unroll
1312
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1630
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1313
1631
  const int k0 = k00 + k01;
1314
1632
 
1315
1633
  #pragma unroll
@@ -1317,13 +1635,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1317
1635
  const int j = j0 + threadIdx.y;
1318
1636
 
1319
1637
  #pragma unroll
1320
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1638
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1321
1639
  const int i = i0 + threadIdx.x;
1322
1640
 
1323
- const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
1641
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
1324
1642
 
1325
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1326
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1643
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
1644
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1327
1645
  x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1328
1646
  }
1329
1647
  }
@@ -1340,72 +1658,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
1340
1658
  ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1341
1659
  }
1342
1660
 
1343
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1661
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1344
1662
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1663
+ constexpr int nwarps = mmq_get_nwarps_device();
1664
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1345
1665
 
1346
- #ifdef NEW_MMA_AVAILABLE
1666
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1347
1667
  int * x_qs = (int *) x_tile;
1348
- half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1668
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1349
1669
  #else
1350
1670
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1351
1671
  int * x_qs = (int *) x_tile;
1352
1672
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1353
1673
  int * x_sc = (int *) (x_dm + txs.dm);
1354
- #endif // NEW_MMA_AVAILABLE
1674
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1675
+
1676
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1677
+ constexpr int nrows = warp_size / threads_per_row;
1678
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1355
1679
 
1356
1680
  #pragma unroll
1357
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1358
- int i = i0 + threadIdx.y;
1681
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1682
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1359
1683
 
1360
1684
  if (need_check) {
1361
1685
  i = min(i, i_max);
1362
1686
  }
1363
1687
 
1364
1688
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1365
- const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
1689
+ const int qs0 = get_int_b4(bxi->qs, txi);
1366
1690
 
1367
- #ifdef NEW_MMA_AVAILABLE
1368
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1369
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1691
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1692
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1693
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1370
1694
  #else
1371
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1372
- #endif // NEW_MMA_AVAILABLE
1695
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1696
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1373
1697
  }
1374
1698
 
1375
- #ifdef NEW_MMA_AVAILABLE
1376
-
1699
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1700
+ constexpr int rows_per_warp = warp_size / 2;
1377
1701
  #pragma unroll
1378
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1379
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1380
-
1381
- if (need_check) {
1382
- i = min(i, i_max);
1383
- }
1702
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1703
+ #if defined(AMD_MFMA_AVAILABLE)
1704
 + // On AMD an 'if' is needed instead of '%' because warp_size == 64:
1705
 + // the '%' wraparound would do double work and lose throughput (MI300X).
1706
 + // On H100 an 'if' condition instead of '%' loses about 100 t/s.
1707
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1708
+ if (i < mmq_y) {
1709
+ #else
1710
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1711
+ {
1712
+ #endif // defined(AMD_MFMA_AVAILABLE)
1713
+ if (need_check) {
1714
+ i = min(i, i_max);
1715
+ }
1384
1716
 
1385
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1717
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1386
1718
 
1387
- const int * scales = (const int *) bxi->scales;
1388
- const int ksc = threadIdx.x % (WARP_SIZE/16);
1719
+ const int * scales = (const int *) bxi->scales;
1720
+ const int ksc = threadIdx.x % 2;
1389
1721
 
1390
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1391
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1722
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1723
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1392
1724
 
1393
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1394
- const uint8_t * m8 = (const uint8_t *) &m32;
1725
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1726
+ const uint8_t * m8 = (const uint8_t *) &m32;
1395
1727
 
1396
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1728
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1397
1729
 
1398
- #pragma unroll
1399
- for (int l = 0; l < int(sizeof(int)); ++l) {
1400
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1730
+ #pragma unroll
1731
+ for (int l = 0; l < sizeof(int); ++l) {
1732
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1733
+ }
1401
1734
  }
1402
1735
  }
1403
-
1404
1736
  #else
1405
-
1406
1737
  #pragma unroll
1407
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
1408
- int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
1738
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1739
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1409
1740
 
1410
1741
  if (need_check) {
1411
1742
  i = min(i, i_max);
@@ -1415,30 +1746,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1415
1746
 
1416
1747
  x_dm[i] = bxi->dm;
1417
1748
  }
1418
-
1749
+ constexpr int rows_per_warp = warp_size / 4;
1419
1750
  #pragma unroll
1420
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1421
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
1751
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1752
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1422
1753
 
1423
1754
  if (need_check) {
1424
1755
  i = min(i, i_max);
1425
1756
  }
1426
1757
 
1427
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
1758
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
1428
1759
 
1429
1760
  const int * scales = (const int *) bxi->scales;
1430
1761
 
1431
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1762
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1432
1763
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1433
1764
 
1434
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1765
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1435
1766
  }
1436
- #endif // NEW_MMA_AVAILABLE
1767
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1437
1768
  }
1438
1769
 
1439
- template <int mmq_x, int mmq_y, int nwarps>
1770
+ template <int mmq_x, int mmq_y>
1440
1771
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1441
1772
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1773
+ constexpr int nwarps = mmq_get_nwarps_device();
1774
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1442
1775
 
1443
1776
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1444
1777
  const int * x_qs = (const int *) x;
@@ -1448,7 +1781,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1448
1781
  const half2 * y_ds = (const half2 *) y;
1449
1782
 
1450
1783
  // #pragma unroll
1451
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1784
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1452
1785
  const int k0 = k00 + k01;
1453
1786
 
1454
1787
  #pragma unroll
@@ -1456,97 +1789,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1456
1789
  const int j = j0 + threadIdx.y;
1457
1790
 
1458
1791
  #pragma unroll
1459
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1792
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1460
1793
  const int i = i0 + threadIdx.x;
1461
1794
 
1462
- const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
1795
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
1463
1796
 
1464
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1465
- &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1797
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
1798
+ &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1466
1799
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1467
1800
  }
1468
1801
  }
1469
1802
  }
1470
1803
  }
1471
1804
 
1472
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1805
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1473
1806
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1807
+ constexpr int nwarps = mmq_get_nwarps_device();
1808
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1474
1809
 
1475
- #ifdef NEW_MMA_AVAILABLE
1810
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1476
1811
  int * x_qs = (int *) x_tile;
1477
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
1812
+ half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1478
1813
  #else
1479
1814
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1480
1815
  int * x_qs = (int *) x_tile;
1481
1816
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1482
1817
  int * x_sc = (int *) (x_dm + txs.dm);
1483
- #endif // NEW_MMA_AVAILABLE
1818
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1819
+
1820
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
1821
+ constexpr int nrows = warp_size / threads_per_row;
1822
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1484
1823
 
1485
1824
  #pragma unroll
1486
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1487
- int i = i0 + threadIdx.y;
1825
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1826
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1488
1827
 
1489
1828
  if (need_check) {
1490
1829
  i = min(i, i_max);
1491
1830
  }
1492
1831
 
1493
1832
  const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1494
- const int ky = QR5_K*threadIdx.x;
1833
+ const int ky = QR5_K*txi;
1495
1834
 
1496
- const int ql = get_int_b4(bxi->qs, threadIdx.x);
1835
+ const int ql = get_int_b4(bxi->qs, txi);
1497
1836
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1498
1837
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1499
1838
 
1500
- const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
1501
- const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1502
- const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1839
+ const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
1840
+ const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1841
+ const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1503
1842
 
1504
- const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1505
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
1843
+ const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1844
+ const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1506
1845
 
1507
- #ifdef NEW_MMA_AVAILABLE
1846
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1508
1847
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1509
1848
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1510
1849
  #else
1511
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1512
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1513
- #endif // NEW_MMA_AVAILABLE
1850
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1851
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1852
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1514
1853
  }
1515
1854
 
1516
- #ifdef NEW_MMA_AVAILABLE
1517
-
1855
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1856
+ constexpr int rows_per_warp = warp_size / 2;
1518
1857
  #pragma unroll
1519
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1520
- int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1521
-
1522
- if (need_check) {
1523
- i = min(i, i_max);
1524
- }
1858
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1859
+ #if defined(AMD_MFMA_AVAILABLE)
1860
 + // On AMD an 'if' is needed instead of '%' because warp_size == 64:
1861
 + // the '%' wraparound would do double work and lose throughput (MI300X).
1862
 + // On H100 an 'if' condition instead of '%' loses about 100 t/s.
1863
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1864
+ if (i < mmq_y) {
1865
+ #else
1866
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1867
+ {
1868
+ #endif // defined(AMD_MFMA_AVAILABLE)
1869
+ if (need_check) {
1870
+ i = min(i, i_max);
1871
+ }
1525
1872
 
1526
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1873
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1527
1874
 
1528
- const int * scales = (const int *) bxi->scales;
1529
- const int ksc = threadIdx.x % (WARP_SIZE/16);
1875
+ const int * scales = (const int *) bxi->scales;
1876
+ const int ksc = threadIdx.x % 2;
1530
1877
 
1531
- const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1532
- const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1878
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1879
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1533
1880
 
1534
- const uint8_t * sc8 = (const uint8_t *) &sc32;
1535
- const uint8_t * m8 = (const uint8_t *) &m32;
1881
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1882
+ const uint8_t * m8 = (const uint8_t *) &m32;
1536
1883
 
1537
- const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1884
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1538
1885
 
1539
1886
  #pragma unroll
1540
- for (int l = 0; l < int(sizeof(int)); ++l) {
1541
- x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1887
+ for (int l = 0; l < int(sizeof(int)); ++l) {
1888
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1889
+ }
1542
1890
  }
1543
1891
  }
1544
-
1545
1892
  #else
1546
-
1547
1893
  #pragma unroll
1548
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
1549
- int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
1894
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1895
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1550
1896
 
1551
1897
  if (need_check) {
1552
1898
  i = min(i, i_max);
@@ -1557,9 +1903,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1557
1903
  x_dm[i] = bxi->dm;
1558
1904
  }
1559
1905
 
1906
+ constexpr int rows_per_warp = warp_size / 4;
1560
1907
  #pragma unroll
1561
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1562
- int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
1908
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1909
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1563
1910
 
1564
1911
  if (need_check) {
1565
1912
  i = min(i, i_max);
@@ -1569,17 +1916,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1569
1916
 
1570
1917
  const int * scales = (const int *) bxi->scales;
1571
1918
 
1572
- const int ksc = threadIdx.x % (WARP_SIZE/8);
1919
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1573
1920
  const int scales8 = unpack_scales_q45_K(scales, ksc);
1574
1921
 
1575
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1922
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1576
1923
  }
1577
- #endif // NEW_MMA_AVAILABLE
1924
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1578
1925
  }
1579
1926
 
1580
- template <int mmq_x, int mmq_y, int nwarps>
1927
+ template <int mmq_x, int mmq_y>
1581
1928
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1582
1929
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1930
+ constexpr int nwarps = mmq_get_nwarps_device();
1931
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1583
1932
 
1584
1933
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1585
1934
  const int * x_qs = (const int *) x;
@@ -1589,7 +1938,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1589
1938
  const half2 * y_ds = (const half2 *) y;
1590
1939
 
1591
1940
  // #pragma unroll
1592
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1941
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1593
1942
  const int k0 = k00 + k01;
1594
1943
 
1595
1944
  #pragma unroll
@@ -1597,36 +1946,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1597
1946
  const int j = j0 + threadIdx.y;
1598
1947
 
1599
1948
  #pragma unroll
1600
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1949
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1601
1950
  const int i = i0 + threadIdx.x;
1602
1951
 
1603
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
1952
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
1604
1953
 
1605
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1606
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1954
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
1955
+ &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1607
1956
  x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1608
1957
  }
1609
1958
  }
1610
1959
  }
1611
1960
  }
1612
1961
 
1613
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1962
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1614
1963
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1964
+ constexpr int nwarps = mmq_get_nwarps_device();
1965
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1615
1966
 
1616
- #ifdef NEW_MMA_AVAILABLE
1967
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1617
1968
  int * x_qs = (int *) x_tile;
1618
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
1619
- int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
1969
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1970
+ int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
1620
1971
  #else
1621
1972
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1622
1973
  int * x_qs = (int *) x_tile;
1623
1974
  float * x_df = (float *) (x_qs + txs.qs);
1624
1975
  int * x_sc = (int *) (x_df + txs.dm);
1625
- #endif // NEW_MMA_AVAILABLE
1976
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1977
+
1978
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
1979
+ constexpr int nrows = warp_size / threads_per_row;
1980
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1626
1981
 
1627
1982
  #pragma unroll
1628
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1629
- int i = i0 + threadIdx.y;
1983
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1984
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1630
1985
 
1631
1986
  if (need_check) {
1632
1987
  i = min(i, i_max);
@@ -1634,67 +1989,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1634
1989
 
1635
1990
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1636
1991
 
1637
- const int ql = get_int_b2(bxi->ql, threadIdx.x);
1992
+ const int ql = get_int_b2(bxi->ql, txi);
1638
1993
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1639
1994
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1640
1995
 
1641
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
1642
- const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
1643
- const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
1996
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
1997
+ const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
1998
+ const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
1644
1999
 
1645
- const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1646
- const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
2000
+ const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2001
+ const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
1647
2002
 
1648
- #ifdef NEW_MMA_AVAILABLE
2003
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1649
2004
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1650
2005
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1651
2006
  #else
1652
- x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1653
- x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1654
- #endif // NEW_MMA_AVAILABLE
2007
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2008
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2009
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1655
2010
  }
1656
2011
 
1657
- const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
1658
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
1659
-
1660
2012
  #pragma unroll
1661
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
1662
- int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
2013
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2014
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1663
2015
 
1664
2016
  if (need_check) {
1665
2017
  i = min(i, i_max);
1666
2018
  }
1667
2019
 
1668
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
2020
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
1669
2021
 
1670
- #ifdef NEW_MMA_AVAILABLE
1671
- x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
2022
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2023
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
1672
2024
  #else
1673
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
1674
- #endif // NEW_MMA_AVAILABLE
2025
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2026
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1675
2027
  }
1676
2028
 
2029
+ constexpr int rows_per_warp = warp_size / 4;
1677
2030
  #pragma unroll
1678
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1679
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
2031
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2032
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1680
2033
 
1681
2034
  if (need_check) {
1682
2035
  i = min(i, i_max);
1683
2036
  }
1684
2037
 
1685
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
2038
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
1686
2039
 
1687
- #ifdef NEW_MMA_AVAILABLE
1688
- x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
2040
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2041
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
1689
2042
  #else
1690
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1691
- #endif // NEW_MMA_AVAILABLE
2043
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2044
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1692
2045
  }
1693
2046
  }
1694
2047
 
1695
- template <int mmq_x, int mmq_y, int nwarps>
2048
+ template <int mmq_x, int mmq_y>
1696
2049
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1697
2050
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2051
+ constexpr int nwarps = mmq_get_nwarps_device();
2052
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1698
2053
 
1699
2054
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1700
2055
  const int * x_qs = (const int *) x;
@@ -1704,7 +2059,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1704
2059
  const float * y_df = (const float *) y;
1705
2060
 
1706
2061
  // #pragma unroll
1707
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2062
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
1708
2063
  const int k0 = k00 + k01;
1709
2064
 
1710
2065
  #pragma unroll
@@ -1712,23 +2067,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1712
2067
  const int j = j0 + threadIdx.y;
1713
2068
 
1714
2069
  #pragma unroll
1715
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2070
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1716
2071
  const int i = i0 + threadIdx.x;
1717
2072
 
1718
- const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
2073
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
1719
2074
 
1720
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1721
- &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
1722
- x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2075
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
2076
+ &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2077
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1723
2078
  }
1724
2079
  }
1725
2080
  }
1726
2081
  }
1727
2082
 
1728
- template <int mmq_x, int mmq_y, int nwarps>
2083
+ template <int mmq_x, int mmq_y>
1729
2084
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1730
2085
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1731
- #ifdef NEW_MMA_AVAILABLE
2086
+ #if defined(AMD_MFMA_AVAILABLE)
2087
+ typedef tile<16, 8, int> tile_A;
2088
+ typedef tile<16, 8, int> tile_B;
2089
+ typedef tile<16, 16, int> tile_C;
2090
+ typedef tile<64, 2, int> tile_load;
2091
+
2092
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
2093
+ constexpr int rows_per_warp = granularity;
2094
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2095
+
2096
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2097
+
2098
+ const int * x_qs = (const int *) x;
2099
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2100
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2101
+ const int * y_qs = (const int *) y + 4;
2102
+ const float * y_df = (const float *) y;
2103
+
2104
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2105
+
2106
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2107
+ const int k0 = k00 + k01;
2108
+
2109
+ tile_A A[ntx];
2110
+ #pragma unroll
2111
+ for (int n = 0; n < ntx; ++n) {
2112
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2113
+ }
2114
+
2115
+ #pragma unroll
2116
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2117
+ tile_B B[1];
2118
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2119
+
2120
+ const int j = j0 + tile_C::get_j(0);
2121
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2122
+
2123
+ #pragma unroll
2124
+ for (int n = 0; n < ntx; ++n) {
2125
+ tile_C C;
2126
+ mma(C, A[n], B[0]);
2127
+
2128
+ #pragma unroll
2129
+ for (int l = 0; l < tile_C::ne; ++l) {
2130
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2131
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2132
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2133
+ }
2134
+ }
2135
+ }
2136
+ }
2137
+ #elif defined(NEW_MMA_AVAILABLE)
1732
2138
 
1733
2139
  typedef tile<16, 4, int> tile_A;
1734
2140
  typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2144,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1738
2144
  constexpr int rows_per_warp = 2 * granularity;
1739
2145
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1740
2146
 
1741
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
2147
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1742
2148
 
1743
2149
  const int * x_qs = (const int *) x;
1744
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1745
- const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
2150
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2151
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
1746
2152
  const int * y_qs = (const int *) y + 4;
1747
2153
  const float * y_df = (const float *) y;
1748
2154
 
@@ -1755,7 +2161,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1755
2161
  #pragma unroll
1756
2162
  for (int n = 0; n < ntx; ++n) {
1757
2163
  #pragma unroll
1758
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2164
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1759
2165
  const int k0 = k00 + k01;
1760
2166
 
1761
2167
  load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2169,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1763
2169
  }
1764
2170
 
1765
2171
  #pragma unroll
1766
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
2172
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
1767
2173
  const int k0 = k00 + k01;
1768
2174
 
1769
2175
  #pragma unroll
@@ -1793,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1793
2199
  float tmp[ntx][tile_C::ne] = {{0.0f}};
1794
2200
 
1795
2201
  #pragma unroll
1796
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
2202
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1797
2203
  tile_B B[2];
1798
2204
  float dB[tile_C::ne/2];
1799
2205
 
@@ -1832,27 +2238,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1832
2238
  #else
1833
2239
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
1834
2240
  NO_DEVICE_CODE;
1835
- #endif // NEW_MMA_AVAILABLE
2241
+ #endif // AMD_MFMA_AVAILABLE
1836
2242
  }
1837
2243
 
1838
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
2244
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1839
2245
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2246
+ constexpr int nwarps = mmq_get_nwarps_device();
2247
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1840
2248
 
1841
- #ifdef NEW_MMA_AVAILABLE
2249
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1842
2250
  int * x_qs = (int *) x_tile;
1843
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2251
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1844
2252
  #else
1845
2253
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
1846
2254
  int * x_qs = (int *) x_tile;
1847
2255
  float * x_df = (float *) (x_qs + txs.qs);
1848
- #endif // NEW_MMA_AVAILABLE
2256
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1849
2257
 
1850
- const int kbx = threadIdx.x / QI4_NL;
1851
- const int kqsx = threadIdx.x % QI4_NL;
2258
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2259
+ constexpr int nrows = warp_size / threads_per_row;
2260
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2261
+ const int kbx = txi / QI4_NL;
2262
+ const int kqsx = txi % QI4_NL;
1852
2263
 
1853
2264
  #pragma unroll
1854
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1855
- int i = i0 + threadIdx.y;
2265
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2266
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1856
2267
 
1857
2268
  if (need_check) {
1858
2269
  i = min(i, i_max);
@@ -1862,22 +2273,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1862
2273
 
1863
2274
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
1864
2275
  const int2 v = get_int_from_table_16(aux_q4);
1865
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
1866
- #ifdef NEW_MMA_AVAILABLE
1867
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
1868
- x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2276
+ const int k0 = kbx * (2 * QI4_NL) + kqsx;
2277
+
2278
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2279
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2280
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
1869
2281
  #else
1870
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
1871
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
1872
- #endif // NEW_MMA_AVAILABLE
2282
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2283
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2284
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1873
2285
  }
1874
2286
 
1875
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
2287
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
2288
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
1876
2289
  const int kbxd = threadIdx.x % blocks_per_tile_x_row;
1877
2290
 
1878
2291
  #pragma unroll
1879
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
1880
- int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
2292
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2293
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
1881
2294
 
1882
2295
  if (need_check) {
1883
2296
  i = min(i, i_max);
@@ -1885,31 +2298,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1885
2298
 
1886
2299
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
1887
2300
 
1888
- #ifdef NEW_MMA_AVAILABLE
1889
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2301
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2302
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
1890
2303
  #else
1891
- x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
1892
- #endif // NEW_MMA_AVAILABLE
2304
+ x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2305
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1893
2306
  }
1894
2307
  }
1895
2308
 
1896
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
2309
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1897
2310
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2311
+ constexpr int nwarps = mmq_get_nwarps_device();
2312
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1898
2313
 
1899
- #ifdef NEW_MMA_AVAILABLE
2314
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1900
2315
  int * x_qs = (int *) x_tile;
1901
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2316
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1902
2317
  #else
1903
2318
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1904
2319
  int * x_qs = (int *) x_tile;
1905
2320
  float * x_df = (float *) (x_qs + txs.qs);
1906
- #endif // NEW_MMA_AVAILABLE
2321
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1907
2322
 
1908
- const int kqsx = threadIdx.x % (QI2_XXS/2);
2323
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2324
+ constexpr int nrows = warp_size / threads_per_row;
2325
+ const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1909
2326
 
1910
2327
  #pragma unroll
1911
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
1912
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
2328
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2329
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1913
2330
 
1914
2331
  if (need_check) {
1915
2332
  i = min(i, i_max);
@@ -1932,42 +2349,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1932
2349
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1933
2350
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1934
2351
 
1935
- #ifdef NEW_MMA_AVAILABLE
2352
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1936
2353
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1937
2354
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1938
2355
  #else
1939
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1940
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1941
- #endif // NEW_MMA_AVAILABLE
2356
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2357
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2358
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1942
2359
  }
1943
2360
 
1944
2361
  const int ls = aux32 >> 28;
1945
2362
  const float d = bxi->d;
1946
- #ifdef NEW_MMA_AVAILABLE
1947
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2363
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2364
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1948
2365
  #else
1949
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1950
- #endif // NEW_MMA_AVAILABLE
2366
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2367
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1951
2368
  }
1952
2369
  }
1953
2370
 
1954
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
2371
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1955
2372
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2373
+ constexpr int nwarps = mmq_get_nwarps_device();
2374
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1956
2375
 
1957
- #ifdef NEW_MMA_AVAILABLE
2376
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1958
2377
  int * x_qs = (int *) x_tile;
1959
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2378
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1960
2379
  #else
1961
2380
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1962
2381
  int * x_qs = (int *) x_tile;
1963
2382
  float * x_df = (float *) (x_qs + txs.qs);
1964
- #endif // NEW_MMA_AVAILABLE
2383
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1965
2384
 
1966
- const int kqsx = threadIdx.x % (QI2_XS/2);
2385
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2386
+ constexpr int nrows = warp_size / threads_per_row;
2387
+ const int kqsx = threadIdx.x % threads_per_row;
1967
2388
 
1968
2389
  #pragma unroll
1969
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
1970
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
2390
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2391
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1971
2392
 
1972
2393
  if (need_check) {
1973
2394
  i = min(i, i_max);
@@ -1986,44 +2407,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1986
2407
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1987
2408
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1988
2409
 
1989
- #ifdef NEW_MMA_AVAILABLE
2410
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1990
2411
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1991
2412
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1992
2413
  #else
1993
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1994
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1995
- #endif // NEW_MMA_AVAILABLE
2414
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2415
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2416
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1996
2417
  }
1997
2418
 
1998
2419
  const int ls = bxi->scales[kqsx];
1999
2420
  const float d = bxi->d;
2000
- #ifdef NEW_MMA_AVAILABLE
2001
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2002
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2421
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2422
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2423
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2003
2424
  #else
2004
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2005
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2006
- #endif // NEW_MMA_AVAILABLE
2425
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2426
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2427
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2007
2428
  }
2008
2429
  }
2009
2430
 
2010
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2431
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2011
2432
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2433
+ constexpr int nwarps = mmq_get_nwarps_device();
2434
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2012
2435
 
2013
- #ifdef NEW_MMA_AVAILABLE
2436
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2014
2437
  int * x_qs = (int *) x_tile;
2015
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2438
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2016
2439
  #else
2017
2440
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2018
2441
  int * x_qs = (int *) x_tile;
2019
2442
  float * x_df = (float *) (x_qs + txs.qs);
2020
- #endif // NEW_MMA_AVAILABLE
2443
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2021
2444
 
2022
- const int kqsx = threadIdx.x % (QI2_S/2);
2445
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2446
+ constexpr int nrows = warp_size / threads_per_row;
2447
+ const int kqsx = threadIdx.x % threads_per_row;
2023
2448
 
2024
2449
  #pragma unroll
2025
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
2026
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
2450
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2451
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2027
2452
 
2028
2453
  if (need_check) {
2029
2454
  i = min(i, i_max);
@@ -2049,44 +2474,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2049
2474
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2050
2475
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2051
2476
 
2052
- #ifdef NEW_MMA_AVAILABLE
2477
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2053
2478
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2054
2479
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2055
2480
  #else
2056
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2057
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2058
- #endif // NEW_MMA_AVAILABLE
2481
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2482
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2483
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2059
2484
  }
2060
2485
 
2061
2486
  const int ls = bxi->scales[kqsx];
2062
2487
  const float d = bxi->d;
2063
- #ifdef NEW_MMA_AVAILABLE
2064
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2065
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2488
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2489
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2490
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2066
2491
  #else
2067
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2068
- x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2069
- #endif // NEW_MMA_AVAILABLE
2492
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2493
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2494
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2070
2495
  }
2071
2496
  }
2072
2497
 
2073
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2498
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2074
2499
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2500
+ constexpr int nwarps = mmq_get_nwarps_device();
2501
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2075
2502
 
2076
- #ifdef NEW_MMA_AVAILABLE
2503
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2077
2504
  int * x_qs = (int *) x_tile;
2078
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2505
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2079
2506
  #else
2080
2507
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2081
2508
  int * x_qs = (int *) x_tile;
2082
2509
  float * x_df = (float *) (x_qs + txs.qs);
2083
- #endif // NEW_MMA_AVAILABLE
2510
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2084
2511
 
2085
- const int kqsx = threadIdx.x % (QI3_XXS/2);
2512
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2513
+ constexpr int nrows = warp_size / threads_per_row;
2514
+ const int kqsx = threadIdx.x % threads_per_row;
2086
2515
 
2087
2516
  #pragma unroll
2088
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
2089
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
2517
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2518
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2090
2519
 
2091
2520
  if (need_check) {
2092
2521
  i = min(i, i_max);
@@ -2107,42 +2536,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2107
2536
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2108
2537
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2109
2538
 
2110
- #ifdef NEW_MMA_AVAILABLE
2539
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2111
2540
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2112
2541
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2113
2542
  #else
2114
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2115
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2116
- #endif // NEW_MMA_AVAILABLE
2543
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2544
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2545
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2117
2546
  }
2118
2547
 
2119
2548
  const int ls = aux32 >> 28;
2120
2549
  const float d = bxi->d;
2121
- #ifdef NEW_MMA_AVAILABLE
2122
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2550
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2551
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2123
2552
  #else
2124
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2125
- #endif // NEW_MMA_AVAILABLE
2553
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2554
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2126
2555
  }
2127
2556
  }
2128
2557
 
2129
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2558
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2130
2559
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2560
+ constexpr int nwarps = mmq_get_nwarps_device();
2561
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2131
2562
 
2132
- #ifdef NEW_MMA_AVAILABLE
2563
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2133
2564
  int * x_qs = (int *) x_tile;
2134
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2565
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2135
2566
  #else
2136
2567
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2137
2568
  int * x_qs = (int *) x_tile;
2138
2569
  float * x_df = (float *) (x_qs + txs.qs);
2139
- #endif // NEW_MMA_AVAILABLE
2570
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2140
2571
 
2141
- const int kqsx = threadIdx.x % (QI3_S/2);
2572
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2573
+ constexpr int nrows = warp_size / threads_per_row;
2574
+ const int kqsx = threadIdx.x % threads_per_row;
2142
2575
 
2143
2576
  #pragma unroll
2144
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
2145
- int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
2577
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2578
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2146
2579
 
2147
2580
  if (need_check) {
2148
2581
  i = min(i, i_max);
@@ -2170,42 +2603,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2170
2603
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2171
2604
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2172
2605
 
2173
- #ifdef NEW_MMA_AVAILABLE
2606
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2174
2607
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2175
2608
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2176
2609
  #else
2177
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2178
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2179
- #endif // NEW_MMA_AVAILABLE
2610
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2611
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2612
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2180
2613
  }
2181
2614
 
2182
2615
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2183
2616
  const float d = bxi->d;
2184
- #ifdef NEW_MMA_AVAILABLE
2185
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2617
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2618
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2186
2619
  #else
2187
- x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2188
- #endif // NEW_MMA_AVAILABLE
2620
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2621
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2189
2622
  }
2190
2623
  }
2191
2624
 
2192
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2625
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2193
2626
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2627
+ constexpr int nwarps = mmq_get_nwarps_device();
2628
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2194
2629
 
2195
- #ifdef NEW_MMA_AVAILABLE
2630
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2196
2631
  int * x_qs = (int *) x_tile;
2197
- half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
2632
+ half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2198
2633
  #else
2199
2634
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2200
2635
  int * x_qs = (int *) x_tile;
2201
2636
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2202
- #endif // NEW_MMA_AVAILABLE
2637
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2203
2638
 
2204
- const int kqsx = threadIdx.x % QI1_S;
2639
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2640
+ constexpr int nrows = warp_size / threads_per_row;
2641
+ const int kqsx = threadIdx.x % threads_per_row;
2205
2642
 
2206
2643
  #pragma unroll
2207
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
2208
- int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
2644
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2645
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2209
2646
 
2210
2647
  if (need_check) {
2211
2648
  i = min(i, i_max);
@@ -2225,66 +2662,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2225
2662
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2226
2663
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2227
2664
 
2228
- #ifdef NEW_MMA_AVAILABLE
2665
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2229
2666
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2230
2667
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2231
2668
  #else
2232
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2233
- x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2234
- #endif // NEW_MMA_AVAILABLE
2669
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2670
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2671
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2235
2672
  }
2236
2673
 
2237
2674
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2238
2675
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2239
2676
 
2240
- #ifdef NEW_MMA_AVAILABLE
2241
- x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2677
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2678
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2242
2679
  #else
2243
- x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2244
- #endif // NEW_MMA_AVAILABLE
2680
+ x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2681
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2245
2682
  }
2246
2683
  }
2247
2684
 
2248
- template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2685
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2249
2686
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2687
+ constexpr int nwarps = mmq_get_nwarps_device();
2688
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2250
2689
 
2251
- #ifdef NEW_MMA_AVAILABLE
2690
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2252
2691
  int * x_qs = (int *) x_tile;
2253
- float * x_df = (float *) (x_qs + WARP_SIZE*2);
2692
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2254
2693
  #else
2255
2694
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2256
2695
  int * x_qs = (int *) x_tile;
2257
2696
  float * x_df = (float *) (x_qs + txs.qs);
2258
- #endif // NEW_MMA_AVAILABLE
2697
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2259
2698
 
2260
- const int kbx = 0; // threadIdx.x / QI4_XS
2261
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2699
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2700
+ constexpr int nrows = warp_size / threads_per_row;
2701
+ const int kqsx = threadIdx.x % threads_per_row;
2262
2702
 
2263
2703
  #pragma unroll
2264
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2265
- int i = i0 + threadIdx.y;
2704
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2705
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2266
2706
 
2267
2707
  if (need_check) {
2268
2708
  i = min(i, i_max);
2269
2709
  }
2270
2710
 
2271
- const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
2711
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2272
2712
 
2273
2713
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2274
2714
  const int2 v = get_int_from_table_16(aux_q4);
2275
- const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2276
- #ifdef NEW_MMA_AVAILABLE
2715
+ const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2716
+
2717
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2277
2718
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2278
2719
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2279
2720
  #else
2280
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2281
- x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2282
- #endif // NEW_MMA_AVAILABLE
2721
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2722
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2723
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2283
2724
  }
2284
2725
 
2726
+ constexpr int rows_per_warp = warp_size / 8;
2285
2727
  #pragma unroll
2286
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2287
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2728
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2729
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
2288
2730
 
2289
2731
  if (need_check) {
2290
2732
  i = min(i, i_max);
@@ -2297,18 +2739,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2297
2739
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2298
2740
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2299
2741
 
2300
- #ifdef NEW_MMA_AVAILABLE
2301
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2742
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2743
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2302
2744
  #else
2303
- x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2304
- #endif // NEW_MMA_AVAILABLE
2745
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2746
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2305
2747
  }
2306
2748
  }
2307
2749
 
2308
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2750
+ template<int mmq_x, int mmq_y, bool need_check>
2309
2751
  static __device__ __forceinline__ void mmq_write_back_dp4a(
2310
2752
  const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
2311
2753
  const int stride, const int i_max, const int j_max) {
2754
+ constexpr int nwarps = mmq_get_nwarps_device();
2755
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2756
+
2312
2757
  #pragma unroll
2313
2758
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2314
2759
  const int j = j0 + threadIdx.y;
@@ -2318,32 +2763,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
2318
2763
  }
2319
2764
 
2320
2765
  #pragma unroll
2321
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2766
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2322
2767
  const int i = i0 + threadIdx.x;
2323
2768
 
2324
2769
  if (need_check && i > i_max) {
2325
2770
  continue;
2326
2771
  }
2327
2772
 
2328
- dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2773
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2329
2774
  }
2330
2775
  }
2331
2776
  }
2332
2777
 
2333
- template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2778
+ template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
2334
2779
  static __device__ __forceinline__ void mmq_write_back_mma(
2335
2780
  const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
2336
2781
  const int stride, const int i_max, const int j_max) {
2337
- typedef tile<16, 8, int> tile_C;
2338
2782
 
2339
2783
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2784
+ constexpr int nwarps = mmq_get_nwarps_device();
2785
+
2786
+ #if defined(AMD_MFMA_AVAILABLE)
2787
+ constexpr int tileC_IJ = mmq_get_granularity_device(0);
2788
+ typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
2789
+ constexpr int rows_per_warp = granularity;
2790
+ #else
2791
+ typedef tile<16, 8, int> tile_C;
2340
2792
  constexpr int rows_per_warp = 2 * granularity;
2793
+ #endif
2341
2794
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2342
2795
 
2343
2796
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2344
- #ifdef NEW_MMA_AVAILABLE
2797
+ #if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
2345
2798
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2346
- #endif // NEW_MMA_AVAILABLE
2799
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2347
2800
 
2348
2801
  #pragma unroll
2349
2802
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2371,179 +2824,181 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2371
2824
 
2372
2825
  // -------------------------------------------------------------------------------------------------------------------------------------
2373
2826
 
2374
- template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
2827
+ template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
2375
2828
  struct mmq_type_traits;
2376
2829
 
2377
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2378
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
2830
+ template <int mmq_x, int mmq_y, bool need_check>
2831
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
2379
2832
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2380
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
2381
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
2382
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2833
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
2834
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
2835
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
2383
2836
  };
2384
2837
 
2385
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2386
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
2838
+ template <int mmq_x, int mmq_y, bool need_check>
2839
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
2387
2840
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2388
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
2389
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2390
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2841
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
2842
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2843
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
2391
2844
  };
2392
2845
 
2393
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2394
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2846
+ template <int mmq_x, int mmq_y, bool need_check>
2847
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
2395
2848
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2396
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2397
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2398
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2849
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
2850
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2851
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2399
2852
  };
2400
2853
 
2401
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2402
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
2854
+ template <int mmq_x, int mmq_y, bool need_check>
2855
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
2403
2856
  static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2404
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
2405
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2406
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2857
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
2858
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2859
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2407
2860
  };
2408
2861
 
2409
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2410
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
2862
+ template <int mmq_x, int mmq_y, bool need_check>
2863
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2411
2864
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2412
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
2413
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2414
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2865
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
2866
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2867
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2415
2868
  };
2416
2869
 
2417
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2418
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
2870
+ template <int mmq_x, int mmq_y, bool need_check>
2871
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2419
2872
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
2420
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
2421
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2422
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2873
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
2874
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
2875
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
2423
2876
  };
2424
2877
 
2425
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2426
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
2878
+ template <int mmq_x, int mmq_y, bool need_check>
2879
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
2427
2880
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2428
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
2429
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2430
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2881
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
2882
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
2883
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
2431
2884
  };
2432
2885
 
2433
- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2434
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
  static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
  static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
  static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
  static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
  static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
  static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
  static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };

- template <int mmq_x, int mmq_y, int nwarps, bool need_check>
- struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
  };

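Editor's aside: every specialization above drops the `nwarps` template parameter, since the kernels in the later hunks derive the warp count internally via `mmq_get_nwarps_device()`. For readers unfamiliar with this dispatch style, here is a minimal, self-contained C++ sketch of the same traits-specialization pattern; every `demo_*` name is invented for illustration and none of this is code shipped in the package.

```cpp
// Illustrative sketch only: per-type kernel selection through a traits struct
// specialized on an enum value, resolved entirely at compile time.
#include <cstdio>

enum demo_type { DEMO_TYPE_A, DEMO_TYPE_B };

using demo_vec_dot_t = void (*)();

template <int mmq_x, int mmq_y> void demo_vec_dot_a() { std::printf("A path %dx%d\n", mmq_x, mmq_y); }
template <int mmq_x, int mmq_y> void demo_vec_dot_b() { std::printf("B path %dx%d\n", mmq_x, mmq_y); }

// Primary template intentionally left undefined: an unsupported type is a compile error.
template <int mmq_x, int mmq_y, demo_type type>
struct demo_type_traits;

template <int mmq_x, int mmq_y>
struct demo_type_traits<mmq_x, mmq_y, DEMO_TYPE_A> {
    static constexpr demo_vec_dot_t vec_dot = demo_vec_dot_a<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y>
struct demo_type_traits<mmq_x, mmq_y, DEMO_TYPE_B> {
    static constexpr demo_vec_dot_t vec_dot = demo_vec_dot_b<mmq_x, mmq_y>;
};

// A generic "kernel" only names the trait; the concrete function is picked per instantiation.
template <demo_type type>
void demo_kernel() {
    constexpr demo_vec_dot_t vec_dot = demo_type_traits<8, 64, type>::vec_dot;
    vec_dot();
}

int main() {
    demo_kernel<DEMO_TYPE_A>();
    demo_kernel<DEMO_TYPE_B>();
    return 0;
}
```

The benefit of the pattern, as in the diff above, is that adding or removing a template parameter of the traits struct touches only the specializations and the one place that instantiates them.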
- template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+ template <ggml_type type, int mmq_x, bool need_check, bool fixup>
  static __device__ __forceinline__ void mul_mat_q_process_tile(
  const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
  const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
  const int stride_row_x, const int ncols_y, const int stride_col_dst,
  const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {

+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int nwarps = mmq_get_nwarps_device();
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
  constexpr int mmq_y = get_mmq_y_device();
- constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;

  extern __shared__ int data_mul_mat_q[];
  int * tile_y = data_mul_mat_q + mmq_x;
- int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);

- #ifdef NEW_MMA_AVAILABLE
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
- constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
  #else
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
- constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
- #endif // NEW_MMA_AVAILABLE
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

  constexpr int blocks_per_iter = MMQ_ITER_K / qk;

- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
  load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
@@ -2551,8 +3006,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
  {
  const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
  #pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;

  tile_y[l] = by0[l];
  }
@@ -2567,8 +3022,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
  {
  const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
  #pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
- int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;

  tile_y[l] = by0[l];
  }
@@ -2576,7 +3031,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

  __syncthreads();

- vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);

  __syncthreads();
  }
@@ -2591,16 +3046,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598

- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+ template <ggml_type type, int mmq_x, bool need_check>
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
- __launch_bounds__(WARP_SIZE*nwarps, 2)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
  #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
  #else
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
- __launch_bounds__(WARP_SIZE*nwarps, 1)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
  #else
- __launch_bounds__(WARP_SIZE*nwarps, 2)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  static __global__ void mul_mat_q(
@@ -2616,6 +3071,9 @@ static __global__ void mul_mat_q(
  return;
  }

+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
  constexpr int mmq_y = get_mmq_y_device();

@@ -2627,10 +3085,10 @@ static __global__ void mul_mat_q(
  // For MoE the correct indices are loaded from ids_dst.
  extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
  #pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
  break;
  }

@@ -2639,7 +3097,7 @@ static __global__ void mul_mat_q(
  __syncthreads();

  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
- #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
  {
  const int wt = blockIdx.z / nchannels_y;
  const int zt = blockIdx.z - wt*nchannels_y;
@@ -2667,10 +3125,10 @@ static __global__ void mul_mat_q(

  // __syncthreads(); // There is no previous tile that could cause a race condition.
  #pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
  break;
  }

@@ -2688,12 +3146,12 @@ static __global__ void mul_mat_q(
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

  constexpr bool fixup = false;
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
  tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
  return;
  }
- #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

  const int64_t blocks_per_ne00 = ncols_x / qk;
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -2745,10 +3203,10 @@ static __global__ void mul_mat_q(

  __syncthreads();
  #pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
  break;
  }

@@ -2766,7 +3224,7 @@ static __global__ void mul_mat_q(
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);

@@ -2812,10 +3270,10 @@ static __global__ void mul_mat_q(
  // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
  __syncthreads();
  #pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
- const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
  break;
  }

@@ -2833,13 +3291,13 @@ static __global__ void mul_mat_q(
  const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
- mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
  }


- template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+ template <ggml_type type, int mmq_x, bool need_check>
  static __global__ void mul_mat_q_stream_k_fixup(
  const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
  const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
@@ -2849,7 +3307,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
  const int64_t blocks_per_ne00 = ncols_x / qk;

- float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

  const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
  const int nty = (nrows_x + mmq_y - 1) / mmq_y;
@@ -2893,10 +3354,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
  const int j = j0 + threadIdx.y;

  #pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
  const int i = i0 + threadIdx.x;

- sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
  }
  }

@@ -2937,14 +3398,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
  }

  #pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
  const int i = i0 + threadIdx.x;

  if (need_check && i > i_max) {
  continue;
  }

- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
  }
  }
  return;
@@ -2955,7 +3416,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
  const int col_high = expert_bounds[zt + 1];
  const int col_diff = col_high - col_low;

- for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
  ids_dst_shared[j] = ids_dst[col_low + j];
  }
  __syncthreads();
@@ -2975,14 +3436,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
  }

  #pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
  const int i = i0 + threadIdx.x;

  if (need_check && i > i_max) {
  continue;
  }

- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
  }
  }
  }
@@ -2996,13 +3457,13 @@ struct mmq_args {
  };

  template<ggml_type type>
- static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
+ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
  const size_t nbs_ids = mmq_x*sizeof(int);
- const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
  const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
- return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
  }

  template <ggml_type type, int mmq_x>
@@ -3010,20 +3471,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
  const int id = ggml_cuda_get_device();
  const int cc = ggml_cuda_info().devices[id].cc;
  const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);
  const int mmq_y = get_mmq_y_host(cc);

- const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+ const dim3 block_dims(warp_size, nwarps, 1);

- const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);

- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
- static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
- if (!shared_memory_limit_raised[id]) {
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
- shared_memory_limit_raised[id] = true;
- }
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);

  const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
  const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
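Editor's aside: the hunk above sizes the launch from the device's reported warp size instead of the fixed WARP_SIZE/MMQ_NWARPS constants, and the hand-written per-device cudaFuncSetAttribute block it removes appears to be folded into the new CUDA_SET_SHARED_MEMORY_LIMIT calls. As a rough, hedged illustration of that shared-memory carveout pattern in plain CUDA, here is a hypothetical standalone example; the kernel and helper names are made up and are not this package's API, and error handling is omitted.

```cpp
// Illustrative sketch only: raising a kernel's dynamic shared-memory cap once
// before launching it with more than the default opt-out limit (48 KiB on
// most NVIDIA GPUs). All demo_* names are hypothetical.
#include <cuda_runtime.h>

__global__ void demo_smem_kernel(float * dst) {
    extern __shared__ float tile[];          // dynamic shared memory
    tile[threadIdx.x] = (float) threadIdx.x;
    __syncthreads();
    dst[threadIdx.x] = tile[threadIdx.x];
}

// Caller must pass nbytes_shared >= blockDim.x * sizeof(float).
static void demo_launch(float * dst, size_t nbytes_shared, cudaStream_t stream) {
    // Without raising this attribute first, a launch that requests more dynamic
    // shared memory than the default limit fails.
    static bool limit_raised = false;        // per-process here for brevity; real code would track per device
    if (!limit_raised) {
        cudaFuncSetAttribute(demo_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes_shared);
        limit_raised = true;
    }
    demo_smem_kernel<<<1, 256, nbytes_shared, stream>>>(dst);
}
```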
@@ -3038,14 +3495,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
  if (!args.use_stream_k) {
  if (args.nrows_x % mmq_y == 0) {
  constexpr bool need_check = false;
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
  sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
  } else {
  constexpr bool need_check = true;
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3065,8 +3522,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

  if (args.nrows_x % mmq_y == 0) {
  constexpr bool need_check = false;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3076,13 +3532,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
  return;
  }

- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
  (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
  args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
  } else {
  constexpr bool need_check = true;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3092,7 +3547,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
  return;
  }

- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
  (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
  args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
  }
@@ -3100,9 +3555,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

  template <ggml_type type>
  void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
- const int id = ggml_cuda_get_device();
- const int cc = ggml_cuda_info().devices[id].cc;
- const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);

  const int mmq_x_max = get_mmq_x_max_host(cc);
  const int mmq_y = get_mmq_y_host(cc);
@@ -3113,7 +3570,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
  for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
  const int granularity = mmq_get_granularity_host(mmq_x, cc);

- if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
  continue;
  }