whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -10,9 +10,9 @@
10
10
  using namespace ggml_cuda_mma;
11
11
 
12
12
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
13
- #define MMQ_ITER_K 256
14
- #define MMQ_ITER_K_MXFP4_FP4 512
15
- #define MMQ_NWARPS 8
13
+ #define MMQ_ITER_K 256
14
+ #define MMQ_ITER_K_FP4 512
15
+ #define MMQ_NWARPS 8
16
16
 
17
17
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
18
18
  typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
46
46
  int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
47
47
  };
48
48
 
49
+ // this struct is used for fp4 data types (currently only used for Blackwell)
50
+ // mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
51
+ // nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
49
52
  struct block_fp4_mmq {
50
- uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
51
- int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
53
+ uint32_t d4[4];
54
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
52
55
  };
53
56
 
54
57
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
@@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
57
60
 
58
61
  static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
59
62
  switch (type_x) {
63
+ case GGML_TYPE_Q1_0:
64
+ return MMQ_Q8_1_DS_LAYOUT_D4;
60
65
  case GGML_TYPE_Q4_0:
61
66
  case GGML_TYPE_Q4_1:
62
67
  return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
68
73
  return MMQ_Q8_1_DS_LAYOUT_D4;
69
74
  case GGML_TYPE_MXFP4:
70
75
  return MMQ_Q8_1_DS_LAYOUT_D4;
76
+ case GGML_TYPE_NVFP4:
77
+ return MMQ_Q8_1_DS_LAYOUT_D4;
71
78
  case GGML_TYPE_Q2_K:
72
79
  return MMQ_Q8_1_DS_LAYOUT_D2S6;
73
80
  case GGML_TYPE_Q3_K:
@@ -100,7 +107,7 @@ struct tile_x_sizes {
100
107
  };
101
108
 
102
109
  static int get_mmq_x_max_host(const int cc) {
103
- return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
110
+ return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
104
111
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
105
112
  #ifdef GGML_CUDA_FORCE_MMQ
106
113
  128 : 64;
@@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) {
110
117
  }
111
118
 
112
119
  static constexpr __device__ int get_mmq_x_max_device() {
113
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
120
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
114
121
  return 128;
115
- #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
122
+ #else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
116
123
 
117
124
  #if defined(GGML_USE_HIP)
118
125
  return 64;
@@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
139
146
 
140
147
  static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
141
148
  #if defined(BLACKWELL_MMA_AVAILABLE)
142
- return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
143
- #else
144
- return MMQ_ITER_K;
149
+ if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
150
+ return MMQ_ITER_K_FP4;
151
+ }
145
152
  #endif // defined(BLACKWELL_MMA_AVAILABLE)
153
+ return MMQ_ITER_K;
146
154
  }
147
155
 
148
156
  static constexpr __device__ int get_mmq_y_device() {
@@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() {
183
191
 
184
192
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
185
193
  switch (type) {
194
+ case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
186
195
  case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
187
196
  case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
188
197
  case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
189
198
  case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
190
199
  case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
191
200
  case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
201
+ case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
192
202
  case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
193
203
  case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
194
204
  case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
206
216
  }
207
217
  }
208
218
 
209
- #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
210
- #define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
211
- #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
212
- #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
213
- #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
214
- #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
219
+ #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
220
+ #define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
221
+ #define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
222
+ #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
223
+ #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
224
+ #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
225
+ #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
215
226
 
216
227
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
217
228
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
220
231
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
221
232
  static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
222
233
  static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
234
+ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
235
+
223
236
 
224
237
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
225
238
  switch (type) {
239
+ case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
226
240
  case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
227
241
  case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
228
242
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
230
244
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
231
245
  // tile sizes are the same for Q8_1 and FP4 for blackwell
232
246
  case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
247
+ #if defined(BLACKWELL_MMA_AVAILABLE)
248
+ case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
249
+ #else
250
+ case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
251
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
233
252
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
234
253
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
235
254
  case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
295
314
 
296
315
  // ------------------------------------------------------------
297
316
 
317
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
318
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
319
+ constexpr int nwarps = mmq_get_nwarps_device();
320
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
321
+
322
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
323
+ int * x_qs = (int *) x_tile;
324
+ float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
325
+ #else
326
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
327
+ int * x_qs = (int *) x_tile;
328
+ float * x_df = (float *) (x_qs + txs.qs);
329
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
330
+
331
+ constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
332
+ constexpr int threads_per_row = blocks_per_iter * QI1_0;
333
+ constexpr int nrows = warp_size / threads_per_row;
334
+ constexpr int scale_entries_per_block = QK1_0 / QK8_1;
335
+ constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
336
+
337
+ const int txi = threadIdx.x % threads_per_row;
338
+ const int kbx = txi / QI1_0;
339
+ const int kqsx = txi % QI1_0;
340
+
341
+ #pragma unroll
342
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
343
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
344
+
345
+ if (need_check) {
346
+ i = min(i, i_max);
347
+ }
348
+
349
+ const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
350
+ const int qs_offset = 4*kqsx;
351
+ const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
352
+ (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
353
+
354
+ int unpacked_bytes[8];
355
+ #pragma unroll
356
+ for (int j = 0; j < 8; ++j) {
357
+ const int shift = j * 4;
358
+ const int bits4 = (qs0 >> shift) & 0x0F;
359
+ const int b0 = (bits4 & 0x01) ? 1 : -1;
360
+ const int b1 = (bits4 & 0x02) ? 1 : -1;
361
+ const int b2 = (bits4 & 0x04) ? 1 : -1;
362
+ const int b3 = (bits4 & 0x08) ? 1 : -1;
363
+ unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
364
+ }
365
+
366
+ const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
367
+ #pragma unroll
368
+ for (int j = 0; j < 8; ++j) {
369
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
370
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
371
+ #else
372
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
373
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
374
+ }
375
+ }
376
+
377
+ const int ksx = threadIdx.x % scale_entries_per_row;
378
+ const int scale_block = ksx / scale_entries_per_block;
379
+
380
+ #pragma unroll
381
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
382
+ int i = i0 + threadIdx.y;
383
+
384
+ if (need_check) {
385
+ i = min(i, i_max);
386
+ }
387
+
388
+ const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
389
+
390
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
391
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
392
+ #else
393
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
394
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
395
+ }
396
+ }
397
+
298
398
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
299
399
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
300
400
  constexpr int nwarps = mmq_get_nwarps_device();
@@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
379
479
  #pragma unroll
380
480
  for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
381
481
  const int i = i0 + threadIdx.x;
382
-
383
482
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
384
483
 
385
484
  int u[2*VDR_Q4_0_Q8_1_MMQ];
386
485
 
387
- #pragma unroll
388
- for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
389
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
390
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
486
+ constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
487
+ constexpr int mcpy_int = max_cpy / sizeof(int);
488
+ static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
489
+
490
+ int tmp0[4], tmp1[4];
491
+
492
+ #pragma unroll
493
+ for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
494
+ ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
495
+ ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
391
496
  }
392
497
 
498
+ u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
499
+ u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
500
+
393
501
  sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
394
502
  (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
395
503
  x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
482
590
  #pragma unroll
483
591
  for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
484
592
  const int i = i0 + threadIdx.x;
485
-
486
593
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
487
594
 
488
595
  int u[2*VDR_Q4_1_Q8_1_MMQ];
489
596
 
490
- #pragma unroll
491
- for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
492
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
493
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
597
+ constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
598
+ constexpr int mcpy_int = max_cpy / sizeof(int);
599
+ static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
600
+
601
+ int tmp0[4], tmp1[4];
602
+
603
+ #pragma unroll
604
+ for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
605
+ ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
606
+ ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
494
607
  }
495
608
 
609
+ u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
610
+ u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
611
+
496
612
  sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
497
613
  (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
498
614
  x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
826
942
  }
827
943
  }
828
944
 
945
+ #ifdef BLACKWELL_MMA_AVAILABLE
946
+ template <int mmq_y, bool need_check>
947
+ static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
948
+ int * __restrict__ x_tile,
949
+ const int kbx0,
950
+ const int i_max,
951
+ const int stride) {
952
+ constexpr int nwarps = mmq_get_nwarps_device();
953
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
954
+ constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
955
+ constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
956
+ constexpr int rows_per_warp = warp_size / threads_per_row;
957
+
958
+ uint32_t * x_u32 = (uint32_t *) x_tile;
959
+
960
+ const int txi = threadIdx.x;
961
+ const int kbx = txi % threads_per_row;
962
+ const int row_in_warp = txi / threads_per_row;
963
+
964
+ const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
965
+ uint32_t * x_u32_scale = x_u32 + 64 + kbx;
966
+
967
+ #pragma unroll
968
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
969
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
970
+
971
+ if constexpr (need_check) {
972
+ i = min(i, i_max);
973
+ }
974
+
975
+ const block_nvfp4 * bxi = bxi_base + i * stride;
976
+ const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
977
+ const int q_base = row_base + 8 * kbx;
978
+
979
+ const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
980
+
981
+ #pragma unroll
982
+ for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
983
+ x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
984
+ x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
985
+ }
986
+
987
+ x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
988
+ }
989
+ }
990
+
991
+ // Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
992
+ // Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
993
+ // m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
994
+ // and the per-type stride constant differ.
995
+ template <int mmq_x, int mmq_y, ggml_type type>
996
+ static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
997
+ const int * __restrict__ y,
998
+ float * __restrict__ sum,
999
+ const int k00) {
1000
+ static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
1001
+ "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
1002
+
1003
+ typedef tile<16, 8, int> tile_A;
1004
+ typedef tile<8, 8, int> tile_B;
1005
+ typedef tile<16, 8, float> tile_C;
1006
+
1007
+ constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
1008
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
+ constexpr int rows_per_warp = 2 * granularity;
1010
+ constexpr int ntx = rows_per_warp / tile_C::I;
1011
+ constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
1012
+
1013
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
1014
+
1015
+ const int * x_qs = (const int *) x;
1016
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
+ const int * y_qs = (const int *) y + 4;
1018
+ const uint32_t * y_sc = (const uint32_t *) y;
1019
+
1020
+ // 2 threads per quad supply the packed scale register to the block_scale MMA,
1021
+ // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1022
+ const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1023
+ const int tidx_B = threadIdx.x / 4;
1024
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1025
+
1026
+ tile_A A[ntx][nfrags];
1027
+ uint32_t scaleA[ntx][nfrags];
1028
+
1029
+ #pragma unroll
1030
+ for (int n = 0; n < ntx; ++n) {
1031
+ #pragma unroll
1032
+ for (int frag = 0; frag < nfrags; ++frag) {
1033
+ const int k0 = k00 + frag * tile_A::J;
1034
+ load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
1035
+ scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
1036
+ }
1037
+ }
1038
+
1039
+ #pragma unroll
1040
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1041
+ tile_B B[nfrags];
1042
+ uint32_t scaleB[nfrags];
1043
+
1044
+ #pragma unroll
1045
+ for (int frag = 0; frag < nfrags; ++frag) {
1046
+ const int k0 = frag * tile_B::J;
1047
+ load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
1048
+ scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
1049
+ }
1050
+
1051
+ #pragma unroll
1052
+ for (int n = 0; n < ntx; ++n) {
1053
+ #pragma unroll
1054
+ for (int frag = 0; frag < nfrags; ++frag) {
1055
+ tile_C C = {};
1056
+ mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
1057
+ #pragma unroll
1058
+ for (int l = 0; l < tile_C::ne; ++l) {
1059
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1060
+ }
1061
+ }
1062
+ }
1063
+ }
1064
+ }
1065
+ #endif // BLACKWELL_MMA_AVAILABLE
1066
+
1067
+
1068
+ template <int mmq_y, bool need_check>
1069
+ static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
1070
+ int * __restrict__ x_tile,
1071
+ const int kb0,
1072
+ const int i_max,
1073
+ const int stride) {
1074
+ constexpr int nwarps = mmq_get_nwarps_device();
1075
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1076
+
1077
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1078
+ int * x_qs = (int *) x_tile;
1079
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1080
+ #else
1081
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
1082
+ int * x_qs = (int *) x_tile;
1083
+ float * x_df = (float *) (x_qs + txs.qs);
1084
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1085
+
1086
+ constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
1087
+ constexpr int rows_per_warp = warp_size / threads_per_row;
1088
+ const int kbx = threadIdx.x % threads_per_row;
1089
+ const int row_in_warp = threadIdx.x / threads_per_row;
1090
+
1091
+ #pragma unroll
1092
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
1093
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
1094
+
1095
+ if constexpr (need_check) {
1096
+ i = min(i, i_max);
1097
+ }
1098
+
1099
+ const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
1100
+ const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
1101
+ const int kqs = 16 * kbx;
1102
+ const int ksc = 4 * kbx;
1103
+
1104
+ #pragma unroll
1105
+ for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
1106
+ const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
1107
+ const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
1108
+
1109
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1110
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
1111
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
1112
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
1113
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
1114
+ x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
1115
+ #else
1116
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
1117
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
1118
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
1119
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
1120
+ x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
1121
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1122
+ }
1123
+ }
1124
+ }
1125
+
829
1126
  template <int mmq_x, int mmq_y>
830
1127
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
831
1128
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
887
1184
  tile_A A[ntx];
888
1185
  #pragma unroll
889
1186
  for (int n = 0; n < ntx; ++n) {
890
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
1187
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
891
1188
  }
892
1189
 
893
1190
  #pragma unroll
894
1191
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
895
1192
  tile_B B;
896
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1193
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
897
1194
 
898
1195
  float dB;
899
1196
  const int j = j0 + tile_C::get_j(0);
@@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
996
1293
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
997
1294
  }
998
1295
 
999
- template <int mmq_x, int mmq_y>
1000
- static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001
- const int * __restrict__ y,
1002
- float * __restrict__ sum,
1003
- const int k00) {
1004
- typedef tile<16, 8, int> tile_A;
1005
- typedef tile<8, 8, int> tile_B;
1006
- typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
1007
-
1008
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
- constexpr int rows_per_warp = 2 * granularity;
1010
- constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
1011
-
1012
- y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
-
1014
- // Match layout from load_tiles_mxfp4_fp4
1015
- const int * x_qs = (const int *) x;
1016
- const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
- const int * y_qs = (const int *) y + 4;
1018
- const uint32_t * y_sc = (const uint32_t *) y;
1019
-
1020
- // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021
- tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022
- uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
-
1024
- // Block scale
1025
- // Each thread has to point to a 4 byte scale value
1026
- // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
-
1028
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
-
1030
- #pragma unroll
1031
- for (int n = 0; n < ntx; ++n) {
1032
- #pragma unroll
1033
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034
- const int k0 = k00 + k01;
1035
-
1036
- load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037
- MMQ_MMA_TILE_X_K_FP4);
1038
-
1039
- // based on block-scaling document, 2 threads in each quad need to supply to the scale value
1040
- const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041
- scaleA[n][k01 / (2 * QI_MXFP4)] =
1042
- *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043
- }
1044
- }
1045
-
1046
- #pragma unroll
1047
- for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048
- #pragma unroll
1049
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050
- tile_B B;
1051
- uint32_t scaleB; // 2xN scales
1052
-
1053
- load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
-
1055
- scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
-
1057
- #pragma unroll
1058
- for (int n = 0; n < ntx; ++n) {
1059
- tile_C C;
1060
-
1061
- mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062
- #pragma unroll
1063
- for (int l = 0; l < tile_C::ne; ++l) {
1064
- sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065
- }
1066
- }
1067
- }
1068
- }
1069
- }
1070
1296
 
1071
1297
  template <int mmq_x, int mmq_y>
1072
1298
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
@@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1128
1354
  tile_A A[ntx];
1129
1355
  #pragma unroll
1130
1356
  for (int n = 0; n < ntx; ++n) {
1131
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1357
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1132
1358
  }
1133
1359
 
1134
1360
  #pragma unroll
1135
1361
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1136
1362
  tile_B B;
1137
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1363
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1138
1364
 
1139
1365
  const int j = j0 + tile_C::get_j(0);
1140
1366
  const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1229
1455
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1230
1456
  }
1231
1457
 
1232
- // Used for Q3_K, IQ2_S, and IQ2_XS
1458
+ // Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
1233
1459
  template <int mmq_x, int mmq_y>
1234
1460
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
1235
1461
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
1268
1494
  template <int mmq_x, int mmq_y>
1269
1495
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1270
1496
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1271
- #if defined(AMD_MFMA_AVAILABLE)
1272
- constexpr data_layout input_layout = get_input_data_layout();
1273
- typedef tile<16, 8, int, input_layout> tile_A;
1274
- typedef tile<16, 8, int, input_layout> tile_B;
1275
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1276
- typedef tile<64, 2, int, input_layout> tile_load;
1277
-
1278
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1279
- constexpr int rows_per_warp = granularity;
1280
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1281
-
1282
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1283
-
1284
- const int * x_qs = (const int *) x;
1285
- const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1286
- const int * y_qs = (const int *) y + 4;
1287
- const float * y_df = (const float *) y;
1288
-
1289
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1290
-
1291
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1292
- const int k0 = k00 + k01;
1293
-
1294
- tile_A A[ntx];
1295
- #pragma unroll
1296
- for (int n = 0; n < ntx; ++n) {
1297
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1298
- }
1299
-
1300
- #pragma unroll
1301
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1302
- tile_B B[1];
1303
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1304
-
1305
- const int j = j0 + tile_C::get_j(0);
1306
- const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1307
-
1308
- #pragma unroll
1309
- for (int n = 0; n < ntx; ++n) {
1310
- tile_C C;
1311
- mma(C, A[n], B[0]);
1312
-
1313
- #pragma unroll
1314
- for (int l = 0; l < tile_C::ne; ++l) {
1315
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1316
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1317
- }
1318
- }
1319
- }
1320
- }
1321
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1497
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1322
1498
  constexpr data_layout input_layout = get_input_data_layout();
1323
1499
  typedef tile<16, 4, int, input_layout> tile_A;
1324
1500
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1343
1519
  tile_A A[ntx];
1344
1520
  #pragma unroll
1345
1521
  for (int n = 0; n < ntx; ++n) {
1346
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1522
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1347
1523
  }
1348
1524
 
1349
1525
  #pragma unroll
1350
1526
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1351
1527
  tile_B B;
1352
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1528
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1353
1529
 
1354
1530
  const int j = j0 + tile_C::get_j(0);
1355
1531
  const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1575
1751
  template <int mmq_x, int mmq_y>
1576
1752
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1577
1753
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1578
- #if defined(AMD_MFMA_AVAILABLE)
1579
- constexpr data_layout input_layout = get_input_data_layout();
1580
- typedef tile<16, 8, int, input_layout> tile_A;
1581
- typedef tile<16, 8, int, input_layout> tile_B;
1582
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583
- typedef tile<64, 2, int, input_layout> tile_load;
1584
-
1585
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1586
- constexpr int rows_per_warp = granularity;
1587
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1588
-
1589
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1590
-
1591
- const int * x_qs = (const int *) x;
1592
- const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1593
- const int * y_qs = (const int *) y + 4;
1594
- const half2 * y_ds = (const half2 *) y;
1595
-
1596
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1597
-
1598
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1599
- const int k0 = k00 + k01;
1600
-
1601
- tile_A A[ntx];
1602
- #pragma unroll
1603
- for (int n = 0; n < ntx; ++n) {
1604
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1605
- }
1606
-
1607
- #pragma unroll
1608
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1609
- tile_B B[1];
1610
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1611
-
1612
- const int j = j0 + tile_C::get_j(0);
1613
- const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1614
- const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1615
- : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1616
- : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1617
-
1618
- tile_C Cm;
1619
- if (k01 >= MMQ_TILE_NE_K * 3/4) {
1620
- tile_A A1;
1621
- A1.x[0] = 0x01010101;
1622
- A1.x[1] = 0x01010101;
1623
- mma(Cm, A1, B[0]);
1624
- }
1625
-
1626
- #pragma unroll
1627
- for (int n = 0; n < ntx; ++n) {
1628
- tile_C Cd;
1629
- mma(Cd, A[n], B[0]);
1630
-
1631
- #pragma unroll
1632
- for (int l = 0; l < tile_C::ne; ++l) {
1633
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634
- const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635
- float tmp = Cd.x[l]*dm.x;
1636
- if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637
- tmp -= Cm.x[l]*dm.y;
1638
- }
1639
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640
- sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641
- }
1642
- }
1643
- }
1644
- }
1645
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1754
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1646
1755
  constexpr data_layout input_layout = get_input_data_layout();
1647
1756
  typedef tile<16, 4, int, input_layout> tile_A;
1648
1757
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1667
1776
  tile_A A[ntx];
1668
1777
  #pragma unroll
1669
1778
  for (int n = 0; n < ntx; ++n) {
1670
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1779
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671
1780
  }
1672
1781
 
1673
1782
  #pragma unroll
1674
1783
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675
1784
  tile_B B;
1676
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1785
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
1786
 
1678
1787
  const int j = j0 + tile_C::get_j(0);
1679
1788
  const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
@@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
2406
2515
  template <int mmq_x, int mmq_y>
2407
2516
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2408
2517
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2409
- #if defined(AMD_MFMA_AVAILABLE)
2410
- constexpr data_layout input_layout = get_input_data_layout();
2411
- typedef tile<16, 8, int, input_layout> tile_A;
2412
- typedef tile<16, 8, int, input_layout> tile_B;
2413
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414
- typedef tile<64, 2, int, input_layout> tile_load;
2415
-
2416
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
2417
- constexpr int rows_per_warp = granularity;
2418
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2419
-
2420
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2421
-
2422
- const int * x_qs = (const int *) x;
2423
- const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2424
- const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2425
- const int * y_qs = (const int *) y + 4;
2426
- const float * y_df = (const float *) y;
2427
-
2428
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2429
-
2430
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2431
- const int k0 = k00 + k01;
2432
-
2433
- tile_A A[ntx];
2434
- #pragma unroll
2435
- for (int n = 0; n < ntx; ++n) {
2436
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2437
- }
2438
-
2439
- #pragma unroll
2440
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2441
- tile_B B[1];
2442
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2443
-
2444
- const int j = j0 + tile_C::get_j(0);
2445
- const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2446
-
2447
- #pragma unroll
2448
- for (int n = 0; n < ntx; ++n) {
2449
- tile_C C;
2450
- mma(C, A[n], B[0]);
2451
-
2452
- #pragma unroll
2453
- for (int l = 0; l < tile_C::ne; ++l) {
2454
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455
- const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457
- }
2458
- }
2459
- }
2460
- }
2461
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
2518
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2462
2519
  constexpr data_layout input_layout = get_input_data_layout();
2463
2520
  typedef tile<16, 4, int, input_layout> tile_A;
2464
2521
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2484
2541
  tile_A A[ntx];
2485
2542
  #pragma unroll
2486
2543
  for (int n = 0; n < ntx; ++n) {
2487
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2544
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488
2545
  }
2489
2546
 
2490
2547
  #pragma unroll
2491
2548
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492
2549
  tile_B B;
2493
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2550
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
2551
 
2495
2552
  const int j = j0 + tile_C::get_j(0);
2496
2553
  const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -3208,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
3208
3265
  template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
3209
3266
  struct mmq_type_traits;
3210
3267
 
3268
+ template <int mmq_x, int mmq_y, bool need_check>
3269
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
3270
+ static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
3271
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
3272
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3273
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3274
+ };
3275
+
3211
3276
  template <int mmq_x, int mmq_y, bool need_check>
3212
3277
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
3213
3278
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
@@ -3253,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3253
3318
  static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
3254
3319
  #ifdef BLACKWELL_MMA_AVAILABLE
3255
3320
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3256
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
3321
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
3257
3322
  #else
3258
3323
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
3259
3324
  static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3261,6 +3326,19 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3261
3326
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3262
3327
  };
3263
3328
 
3329
+ template <int mmq_x, int mmq_y, bool need_check>
3330
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
3331
+ static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
3332
+ #ifdef BLACKWELL_MMA_AVAILABLE
3333
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
3334
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
3335
+ #else
3336
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
3337
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3338
+ #endif // BLACKWELL_MMA_AVAILABLE
3339
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3340
+ };
3341
+
3264
3342
  template <int mmq_x, int mmq_y, bool need_check>
3265
3343
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
3266
3344
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -3392,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
3392
3470
 
3393
3471
  #if defined(BLACKWELL_MMA_AVAILABLE)
3394
3472
  // FP4 tile stores 8 blocks
3395
- constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
3473
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
3396
3474
  #else
3397
3475
  constexpr int ne_block = 4 * QK8_1;
3398
3476
  #endif // defined(BLACKWELL_MMA_AVAILABLE)
@@ -3464,10 +3542,10 @@ template <ggml_type type, int mmq_x, bool need_check>
3464
3542
  static __global__ void mul_mat_q(
3465
3543
  const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
3466
3544
  const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3467
- const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3468
- const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3469
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3470
- const int ncols_max) {
3545
+ const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3546
+ const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3547
+ const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3548
+ const uint3 ntx) {
3471
3549
 
3472
3550
  // Skip unused template specializations for faster compilation:
3473
3551
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3481,8 +3559,7 @@ static __global__ void mul_mat_q(
3481
3559
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
3482
3560
  constexpr int mmq_y = get_mmq_y_device();
3483
3561
 
3484
- const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
3485
- const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3562
+ const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3486
3563
 
3487
3564
  // Initialize the ids for writing back data with just the index.
3488
3565
  // For regular matrix multiplications this is never changed.
@@ -3503,8 +3580,9 @@ static __global__ void mul_mat_q(
3503
3580
  // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
3504
3581
  #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3505
3582
  {
3506
- const int wt = blockIdx.z / nchannels_y;
3507
- const int zt = blockIdx.z - wt*nchannels_y;
3583
+ const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
3584
+ const int wt = tmp2.x;
3585
+ const int zt = tmp2.y;
3508
3586
  const int jt = blockIdx.y;
3509
3587
  const int it = blockIdx.x;
3510
3588
 
@@ -3547,40 +3625,40 @@ static __global__ void mul_mat_q(
3547
3625
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3548
3626
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3549
3627
 
3550
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3628
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3551
3629
 
3552
3630
  constexpr bool fixup = false;
3553
3631
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3554
3632
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3555
- tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
3633
+ tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
3556
3634
  return;
3557
3635
  }
3558
- #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3559
-
3560
- constexpr int ITER_K = get_iter_k(type);
3636
+ #endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3561
3637
 
3562
- const int64_t blocks_per_ne00 = ncols_x / qk;
3563
- constexpr int blocks_per_iter = ITER_K / qk;
3638
+ constexpr int ITER_K = get_iter_k(type);
3639
+ constexpr int blocks_per_iter = ITER_K / qk;
3564
3640
 
3565
3641
  // kbc == k block continuous, current index in continuous ijk space.
3566
- int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3567
- int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3642
+ int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3643
+ int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3568
3644
 
3569
- kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3570
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
3645
+ kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
3646
+ kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
3571
3647
 
3572
3648
  // kb0 == k index when doing the matrix multiplication for an output tile.
3573
- int kb0_start = kbc % blocks_per_ne00;
3574
- int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
3575
- while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
3576
- int tmp = kbc;
3577
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3578
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3579
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3580
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3581
- const int zt = tmp / (ntx*blocks_per_ne00);
3582
- tmp -= zt * (ntx*blocks_per_ne00);
3583
- const int jt = tmp / blocks_per_ne00;
3649
+ int kb0_start = fastmodulo(kbc, blocks_per_ne00);
3650
+ int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
3651
+ while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
3652
+ int tmp = fastdiv(kbc, blocks_per_ne00);
3653
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3654
+ const int jt = tmp2.y;
3655
+ tmp = tmp2.x;
3656
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3657
+ const int zt = tmp2.y;
3658
+ tmp = tmp2.x;
3659
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3660
+ const int wt = tmp2.y;
3661
+ const int it = tmp2.x;
3584
3662
 
3585
3663
  // Defaults for regular matrix multiplication:
3586
3664
  int col_low = 0;
@@ -3598,11 +3676,11 @@ static __global__ void mul_mat_q(
3598
3676
  offset_dst = 0;
3599
3677
 
3600
3678
  if (jt*mmq_x >= col_diff) {
3601
- kbc += blocks_per_ne00;
3602
- kbc -= kbc % blocks_per_ne00;
3679
+ kbc += blocks_per_ne00.z;
3680
+ kbc -= fastmodulo(kbc, blocks_per_ne00);
3603
3681
 
3604
3682
  kb0_start = 0;
3605
- kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3683
+ kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
3606
3684
 
3607
3685
  continue;
3608
3686
  }
@@ -3627,32 +3705,34 @@ static __global__ void mul_mat_q(
3627
3705
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3628
3706
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3629
3707
 
3630
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3708
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3631
3709
 
3632
3710
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
3633
3711
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3634
3712
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3635
3713
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3636
3714
 
3637
- kbc += blocks_per_ne00;
3638
- kbc -= kbc % blocks_per_ne00;
3715
+ kbc += blocks_per_ne00.z;
3716
+ kbc -= fastmodulo(kbc, blocks_per_ne00);
3639
3717
 
3640
3718
  kb0_start = 0;
3641
- kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3719
+ kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
3642
3720
  }
3643
3721
 
3644
3722
  if (kbc >= kbc_stop) {
3645
3723
  return;
3646
3724
  }
3647
3725
 
3648
- int tmp = kbc;
3649
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3650
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3651
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3652
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3653
- const int zt = tmp / (ntx*blocks_per_ne00);
3654
- tmp -= zt * (ntx*blocks_per_ne00);
3655
- const int jt = tmp / blocks_per_ne00;
3726
+ int tmp = fastdiv(kbc, blocks_per_ne00);
3727
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3728
+ const int jt = tmp2.y;
3729
+ tmp = tmp2.x;
3730
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3731
+ const int zt = tmp2.y;
3732
+ tmp = tmp2.x;
3733
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3734
+ const int wt = tmp2.y;
3735
+ const int it = tmp2.x;
3656
3736
 
3657
3737
  // Defaults for regular matrix multiplication:
3658
3738
  int col_low = 0;
@@ -3694,7 +3774,7 @@ static __global__ void mul_mat_q(
3694
3774
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3695
3775
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3696
3776
 
3697
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3777
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3698
3778
 
3699
3779
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
3700
3780
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
@@ -3703,46 +3783,37 @@ static __global__ void mul_mat_q(
3703
3783
  }
3704
3784
 
3705
3785
  template <ggml_type type, int mmq_x, bool need_check>
3706
- static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3707
- const int32_t * expert_bounds,
3708
- float * __restrict__ dst,
3709
- const float * __restrict__ tmp_last_tile,
3710
- const int ncols_x,
3711
- const int nrows_x,
3712
- const int ncols_dst,
3713
- const size_t stride_col_dst,
3714
- const int nchannels_y,
3715
- const size_t stride_channel_dst,
3716
- const int nsamples_y,
3717
- const size_t stride_sample_dst,
3718
- const int ncols_max) {
3719
- constexpr int mmq_y = get_mmq_y_device();
3720
- constexpr int qk = ggml_cuda_type_traits<type>::qk;
3721
- constexpr int ITER_K = get_iter_k(type);
3722
-
3723
- constexpr int blocks_per_iter = ITER_K / qk;
3724
- const int64_t blocks_per_ne00 = ncols_x / qk;
3786
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
3787
+ static __global__ void mul_mat_q_stream_k_fixup(
3788
+ const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
3789
+ float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
3790
+ const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
3791
+ const int stride_sample_dst, const uint3 ntx) {
3792
+ constexpr int mmq_y = get_mmq_y_device();
3793
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
3794
+ constexpr int ITER_K = get_iter_k(type);
3795
+ constexpr int blocks_per_iter = ITER_K / qk;
3725
3796
 
3726
- constexpr int nwarps = mmq_get_nwarps_device();
3797
+ constexpr int nwarps = mmq_get_nwarps_device()/2;
3727
3798
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3728
3799
 
3729
- float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3800
+ float sum[mmq_x / nwarps] = {0.0f};
3801
+ const int i = blockIdx.y*warp_size + threadIdx.x;
3730
3802
 
3731
- const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
3732
- const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3803
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3733
3804
 
3734
3805
  const int bidx0 = blockIdx.x;
3735
3806
 
3736
3807
  // kbc == k block continuous, current index in continuous ijk space.
3737
- int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3738
- int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3808
+ int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3809
+ int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3739
3810
 
3740
- kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
3741
- kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
3811
+ kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
3812
+ kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
3742
3813
 
3743
3814
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
3744
- const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
3745
- const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
3815
+ const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
3816
+ const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
3746
3817
  if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
3747
3818
  return;
3748
3819
  }
@@ -3751,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3751
3822
 
3752
3823
  // Iterate over previous blocks and sum up partial sums written to fixup buffer.
3753
3824
  // All CUDA blocks that get here must have a previous block that needs a fixup.
3754
- int64_t bidx = bidx0 - 1;
3755
- int64_t kbc_stop = kbc0;
3825
+ int bidx = bidx0 - 1;
3826
+ int kbc_stop = kbc0;
3756
3827
  while(true) {
3757
- int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3758
- kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3828
+ int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3829
+ kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
3759
3830
 
3760
3831
  if (kbc == kbc_stop) { // Did not have any data.
3761
3832
  bidx--;
@@ -3765,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3765
3836
 
3766
3837
  any_fixup = true;
3767
3838
 
3839
+
3768
3840
  #pragma unroll
3769
3841
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3770
3842
  const int j = j0 + threadIdx.y;
3771
3843
 
3772
- #pragma unroll
3773
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3774
- const int i = i0 + threadIdx.x;
3775
-
3776
- sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3777
- }
3844
+ sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3778
3845
  }
3779
3846
 
3780
3847
  // If this block started in a previous tile we are done and don't need to combine additional partial results.
3781
- if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
3848
+ if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
3782
3849
  break;
3783
3850
  }
3784
3851
  bidx--;
@@ -3789,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3789
3856
  return;
3790
3857
  }
3791
3858
 
3792
- int tmp = kbc0;
3793
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3794
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3795
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3796
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3797
- const int zt = tmp / (ntx*blocks_per_ne00);
3798
- tmp -= zt * (ntx*blocks_per_ne00);
3799
- const int jt = tmp / blocks_per_ne00;
3859
+ int tmp = fastdiv(kbc0, blocks_per_ne00);
3860
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3861
+ const int jt = tmp2.y;
3862
+ tmp = tmp2.x;
3863
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3864
+ const int zt = tmp2.y;
3865
+ tmp = tmp2.x;
3866
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3867
+ const int wt = tmp2.y;
3868
+ const int it = tmp2.x;
3800
3869
 
3801
3870
  if (!ids_dst) {
3802
3871
  const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
@@ -3804,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3804
3873
 
3805
3874
  const int i_max = nrows_x - it*mmq_y - 1;
3806
3875
  const int j_max = ncols_dst - jt*mmq_x - 1;
3876
+ if (need_check && i > i_max) {
3877
+ return;
3878
+ }
3807
3879
 
3808
3880
  #pragma unroll
3809
3881
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3813,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3813
3885
  return;
3814
3886
  }
3815
3887
 
3816
- #pragma unroll
3817
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3818
- const int i = i0 + threadIdx.x;
3819
-
3820
- if (need_check && i > i_max) {
3821
- continue;
3822
- }
3823
-
3824
- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3825
- }
3888
+ dst[j*stride_col_dst + i] += sum[j0/nwarps];
3826
3889
  }
3827
3890
  return;
3828
3891
  }
@@ -3842,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3842
3905
 
3843
3906
  const int i_max = nrows_x - it*mmq_y - 1;
3844
3907
  const int j_max = col_diff - jt*mmq_x - 1;
3908
+ if (need_check && i > i_max) {
3909
+ return;
3910
+ }
3845
3911
 
3846
3912
  #pragma unroll
3847
3913
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3851,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
3851
3917
  return;
3852
3918
  }
3853
3919
 
3854
- #pragma unroll
3855
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3856
- const int i = i0 + threadIdx.x;
3857
-
3858
- if (need_check && i > i_max) {
3859
- continue;
3860
- }
3861
-
3862
- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3863
- }
3920
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
3864
3921
  }
3865
3922
  }
3866
3923
 
@@ -3908,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3908
3965
  const int channel_ratio = args.nchannels_y / args.nchannels_x;
3909
3966
  const int sample_ratio = args.nsamples_y / args.nsamples_x;
3910
3967
 
3968
+ const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
3969
+ const uint3 ntx_fd = init_fastdiv_values(ntx);
3970
+ const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
3971
+ const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
3972
+ const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
3973
+ const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
3974
+
3911
3975
  if (!args.use_stream_k) {
3912
3976
  if (args.nrows_x % mmq_y == 0) {
3913
3977
  constexpr bool need_check = false;
3914
3978
  mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3915
3979
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3916
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3917
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3918
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3919
- args.ncols_max);
3980
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3981
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3982
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3983
+ ntx_fd);
3920
3984
  } else {
3921
3985
  constexpr bool need_check = true;
3922
3986
  mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3923
3987
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3924
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3925
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3926
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3927
- args.ncols_max);
3988
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3989
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3990
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3991
+ ntx_fd);
3928
3992
  }
3929
3993
  return;
3930
3994
  }
3931
3995
 
3932
- const dim3 block_nums_stream_k(nsm, 1, 1);
3933
- const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
3996
+ // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
3997
+ // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
3998
+ const int ntiles_dst = ntx * nty * ntzw;
3999
+ const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
4000
+ const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
4001
+ const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
4002
+
4003
+ GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
4004
+
4005
+ const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
3934
4006
 
3935
4007
  ggml_cuda_pool & pool = ctx.pool(id);
3936
4008
  ggml_cuda_pool_alloc<float> tmp_fixup(pool);
@@ -3938,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3938
4010
  tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
3939
4011
  }
3940
4012
 
4013
+ const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
4014
+ const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
4015
+
3941
4016
  if (args.nrows_x % mmq_y == 0) {
3942
4017
  constexpr bool need_check = false;
3943
4018
  mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3944
4019
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3945
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3946
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3947
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3948
- args.ncols_max);
4020
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
4021
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
4022
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
4023
+ ntx_fd);
3949
4024
 
3950
4025
  if (!fixup_needed) {
3951
4026
  return;
3952
4027
  }
3953
4028
 
3954
- mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3955
- (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3956
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3957
- args.ncols_max);
4029
+ CUDA_CHECK(cudaGetLastError());
4030
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
4031
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
4032
+ args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
4033
+ ntx_fd);
3958
4034
  } else {
3959
4035
  constexpr bool need_check = true;
3960
4036
  mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3961
4037
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3962
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3963
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3964
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3965
- args.ncols_max);
4038
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
4039
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
4040
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
4041
+ ntx_fd);
3966
4042
 
3967
4043
  if (!fixup_needed) {
3968
4044
  return;
3969
4045
  }
3970
4046
 
3971
- mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3972
- (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3973
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3974
- args.ncols_max);
4047
+ CUDA_CHECK(cudaGetLastError());
4048
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
4049
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
4050
+ args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
4051
+ ntx_fd);
3975
4052
  }
3976
4053
  }
3977
4054
 
@@ -4069,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
4069
4146
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
4070
4147
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
4071
4148
  extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
4149
+ extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
4072
4150
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
4073
4151
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
4074
4152
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -4095,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q(
4095
4173
  const int64_t src1_padded_row_size, cudaStream_t stream);
4096
4174
 
4097
4175
  bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
4176
+