whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -10,9 +10,9 @@
10
10
  using namespace ggml_cuda_mma;
11
11
 
12
12
  #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
13
- #define MMQ_ITER_K 256
14
- #define MMQ_ITER_K_MXFP4_FP4 512
15
- #define MMQ_NWARPS 8
13
+ #define MMQ_ITER_K 256
14
+ #define MMQ_ITER_K_FP4 512
15
+ #define MMQ_NWARPS 8
16
16
 
17
17
  typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
18
18
  typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
@@ -46,9 +46,12 @@ struct block_q8_1_mmq {
46
46
  int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
47
47
  };
48
48
 
49
+ // this struct is used for fp4 data types (currently only used for Blackwell)
50
+ // mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits
51
+ // nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales
49
52
  struct block_fp4_mmq {
50
- uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
51
- int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
53
+ uint32_t d4[4];
54
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte)
52
55
  };
53
56
 
54
57
  static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
@@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
57
60
 
58
61
  static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
59
62
  switch (type_x) {
63
+ case GGML_TYPE_Q1_0:
64
+ return MMQ_Q8_1_DS_LAYOUT_D4;
60
65
  case GGML_TYPE_Q4_0:
61
66
  case GGML_TYPE_Q4_1:
62
67
  return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
68
73
  return MMQ_Q8_1_DS_LAYOUT_D4;
69
74
  case GGML_TYPE_MXFP4:
70
75
  return MMQ_Q8_1_DS_LAYOUT_D4;
76
+ case GGML_TYPE_NVFP4:
77
+ return MMQ_Q8_1_DS_LAYOUT_D4;
71
78
  case GGML_TYPE_Q2_K:
72
79
  return MMQ_Q8_1_DS_LAYOUT_D2S6;
73
80
  case GGML_TYPE_Q3_K:
@@ -100,7 +107,7 @@ struct tile_x_sizes {
100
107
  };
101
108
 
102
109
  static int get_mmq_x_max_host(const int cc) {
103
- return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
110
+ return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
104
111
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
105
112
  #ifdef GGML_CUDA_FORCE_MMQ
106
113
  128 : 64;
@@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) {
110
117
  }
111
118
 
112
119
  static constexpr __device__ int get_mmq_x_max_device() {
113
- #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
120
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
114
121
  return 128;
115
- #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
122
+ #else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
116
123
 
117
124
  #if defined(GGML_USE_HIP)
118
125
  return 64;
@@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) {
139
146
 
140
147
  static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
141
148
  #if defined(BLACKWELL_MMA_AVAILABLE)
142
- return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
143
- #else
144
- return MMQ_ITER_K;
149
+ if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) {
150
+ return MMQ_ITER_K_FP4;
151
+ }
145
152
  #endif // defined(BLACKWELL_MMA_AVAILABLE)
153
+ return MMQ_ITER_K;
146
154
  }
147
155
 
148
156
  static constexpr __device__ int get_mmq_y_device() {
@@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() {
183
191
 
184
192
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
185
193
  switch (type) {
194
+ case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
186
195
  case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
187
196
  case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
188
197
  case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
189
198
  case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
190
199
  case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
191
200
  case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
201
+ case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
192
202
  case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
193
203
  case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
194
204
  case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
206
216
  }
207
217
  }
208
218
 
209
- #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
210
- #define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
211
- #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
212
- #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
213
- #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
214
- #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
219
+ #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
220
+ #define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell
221
+ #define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic
222
+ #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
223
+ #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
224
+ #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
225
+ #define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
215
226
 
216
227
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
217
228
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
220
231
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
221
232
  static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
222
233
  static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
234
+ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
235
+
223
236
 
224
237
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
225
238
  switch (type) {
239
+ case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
226
240
  case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
227
241
  case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
228
242
  case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
230
244
  case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
231
245
  // tile sizes are the same for Q8_1 and FP4 for blackwell
232
246
  case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
247
+ #if defined(BLACKWELL_MMA_AVAILABLE)
248
+ case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4;
249
+ #else
250
+ case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
251
+ #endif // defined(BLACKWELL_MMA_AVAILABLE)
233
252
  case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
234
253
  case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
235
254
  case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
295
314
 
296
315
  // ------------------------------------------------------------
297
316
 
317
+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
318
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
319
+ constexpr int nwarps = mmq_get_nwarps_device();
320
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
321
+
322
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
323
+ int * x_qs = (int *) x_tile;
324
+ float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
325
+ #else
326
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
327
+ int * x_qs = (int *) x_tile;
328
+ float * x_df = (float *) (x_qs + txs.qs);
329
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
330
+
331
+ constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
332
+ constexpr int threads_per_row = blocks_per_iter * QI1_0;
333
+ constexpr int nrows = warp_size / threads_per_row;
334
+ constexpr int scale_entries_per_block = QK1_0 / QK8_1;
335
+ constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
336
+
337
+ const int txi = threadIdx.x % threads_per_row;
338
+ const int kbx = txi / QI1_0;
339
+ const int kqsx = txi % QI1_0;
340
+
341
+ #pragma unroll
342
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
343
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
344
+
345
+ if (need_check) {
346
+ i = min(i, i_max);
347
+ }
348
+
349
+ const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
350
+ const int qs_offset = 4*kqsx;
351
+ const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
352
+ (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
353
+
354
+ int unpacked_bytes[8];
355
+ #pragma unroll
356
+ for (int j = 0; j < 8; ++j) {
357
+ const int shift = j * 4;
358
+ const int bits4 = (qs0 >> shift) & 0x0F;
359
+ const int b0 = (bits4 & 0x01) ? 1 : -1;
360
+ const int b1 = (bits4 & 0x02) ? 1 : -1;
361
+ const int b2 = (bits4 & 0x04) ? 1 : -1;
362
+ const int b3 = (bits4 & 0x08) ? 1 : -1;
363
+ unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
364
+ }
365
+
366
+ const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
367
+ #pragma unroll
368
+ for (int j = 0; j < 8; ++j) {
369
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
370
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
371
+ #else
372
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
373
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
374
+ }
375
+ }
376
+
377
+ const int ksx = threadIdx.x % scale_entries_per_row;
378
+ const int scale_block = ksx / scale_entries_per_block;
379
+
380
+ #pragma unroll
381
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
382
+ int i = i0 + threadIdx.y;
383
+
384
+ if (need_check) {
385
+ i = min(i, i_max);
386
+ }
387
+
388
+ const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
389
+
390
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
391
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
392
+ #else
393
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
394
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
395
+ }
396
+ }
397
+
298
398
  template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
299
399
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
300
400
  constexpr int nwarps = mmq_get_nwarps_device();
@@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
379
479
  #pragma unroll
380
480
  for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
381
481
  const int i = i0 + threadIdx.x;
382
-
383
482
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
384
483
 
385
484
  int u[2*VDR_Q4_0_Q8_1_MMQ];
386
485
 
387
- #pragma unroll
388
- for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
389
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
390
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
486
+ constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
487
+ constexpr int mcpy_int = max_cpy / sizeof(int);
488
+ static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
489
+
490
+ int tmp0[4], tmp1[4];
491
+
492
+ #pragma unroll
493
+ for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
494
+ ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
495
+ ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
391
496
  }
392
497
 
498
+ u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
499
+ u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
500
+
393
501
  sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
394
502
  (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
395
503
  x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
482
590
  #pragma unroll
483
591
  for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
484
592
  const int i = i0 + threadIdx.x;
485
-
486
593
  const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
487
594
 
488
595
  int u[2*VDR_Q4_1_Q8_1_MMQ];
489
596
 
490
- #pragma unroll
491
- for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
492
- u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
493
- u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
597
+ constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
598
+ constexpr int mcpy_int = max_cpy / sizeof(int);
599
+ static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
600
+
601
+ int tmp0[4], tmp1[4];
602
+
603
+ #pragma unroll
604
+ for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
605
+ ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
606
+ ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
494
607
  }
495
608
 
609
+ u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
610
+ u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
611
+
496
612
  sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
497
613
  (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
498
614
  x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
826
942
  }
827
943
  }
828
944
 
945
+ #ifdef BLACKWELL_MMA_AVAILABLE
946
+ template <int mmq_y, bool need_check>
947
+ static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x,
948
+ int * __restrict__ x_tile,
949
+ const int kbx0,
950
+ const int i_max,
951
+ const int stride) {
952
+ constexpr int nwarps = mmq_get_nwarps_device();
953
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
954
+ constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4);
955
+ constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block
956
+ constexpr int rows_per_warp = warp_size / threads_per_row;
957
+
958
+ uint32_t * x_u32 = (uint32_t *) x_tile;
959
+
960
+ const int txi = threadIdx.x;
961
+ const int kbx = txi % threads_per_row;
962
+ const int row_in_warp = txi / threads_per_row;
963
+
964
+ const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx;
965
+ uint32_t * x_u32_scale = x_u32 + 64 + kbx;
966
+
967
+ #pragma unroll
968
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
969
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
970
+
971
+ if constexpr (need_check) {
972
+ i = min(i, i_max);
973
+ }
974
+
975
+ const block_nvfp4 * bxi = bxi_base + i * stride;
976
+ const int row_base = i * MMQ_MMA_TILE_X_K_FP4;
977
+ const int q_base = row_base + 8 * kbx;
978
+
979
+ const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
980
+
981
+ #pragma unroll
982
+ for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
983
+ x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0];
984
+ x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1];
985
+ }
986
+
987
+ x_u32_scale[row_base] = get_int_b4(bxi->d, 0);
988
+ }
989
+ }
990
+
991
+ // Shared MMA kernel for MXFP4 and NVFP4 on Blackwell.
992
+ // Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per
993
+ // m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3)
994
+ // and the per-type stride constant differ.
995
+ template <int mmq_x, int mmq_y, ggml_type type>
996
+ static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x,
997
+ const int * __restrict__ y,
998
+ float * __restrict__ sum,
999
+ const int k00) {
1000
+ static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4,
1001
+ "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4");
1002
+
1003
+ typedef tile<16, 8, int> tile_A;
1004
+ typedef tile<8, 8, int> tile_B;
1005
+ typedef tile<16, 8, float> tile_C;
1006
+
1007
+ constexpr int stride = MMQ_MMA_TILE_X_K_FP4;
1008
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
+ constexpr int rows_per_warp = 2 * granularity;
1010
+ constexpr int ntx = rows_per_warp / tile_C::I;
1011
+ constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J;
1012
+
1013
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K);
1014
+
1015
+ const int * x_qs = (const int *) x;
1016
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
+ const int * y_qs = (const int *) y + 4;
1018
+ const uint32_t * y_sc = (const uint32_t *) y;
1019
+
1020
+ // 2 threads per quad supply the packed scale register to the block_scale MMA,
1021
+ // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1022
+ const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1023
+ const int tidx_B = threadIdx.x / 4;
1024
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1025
+
1026
+ tile_A A[ntx][nfrags];
1027
+ uint32_t scaleA[ntx][nfrags];
1028
+
1029
+ #pragma unroll
1030
+ for (int n = 0; n < ntx; ++n) {
1031
+ #pragma unroll
1032
+ for (int frag = 0; frag < nfrags; ++frag) {
1033
+ const int k0 = k00 + frag * tile_A::J;
1034
+ load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride);
1035
+ scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J];
1036
+ }
1037
+ }
1038
+
1039
+ #pragma unroll
1040
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1041
+ tile_B B[nfrags];
1042
+ uint32_t scaleB[nfrags];
1043
+
1044
+ #pragma unroll
1045
+ for (int frag = 0; frag < nfrags; ++frag) {
1046
+ const int k0 = frag * tile_B::J;
1047
+ load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K);
1048
+ scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag];
1049
+ }
1050
+
1051
+ #pragma unroll
1052
+ for (int n = 0; n < ntx; ++n) {
1053
+ #pragma unroll
1054
+ for (int frag = 0; frag < nfrags; ++frag) {
1055
+ tile_C C = {};
1056
+ mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]);
1057
+ #pragma unroll
1058
+ for (int l = 0; l < tile_C::ne; ++l) {
1059
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1060
+ }
1061
+ }
1062
+ }
1063
+ }
1064
+ }
1065
+ #endif // BLACKWELL_MMA_AVAILABLE
1066
+
1067
+
1068
+ template <int mmq_y, bool need_check>
1069
+ static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
1070
+ int * __restrict__ x_tile,
1071
+ const int kb0,
1072
+ const int i_max,
1073
+ const int stride) {
1074
+ constexpr int nwarps = mmq_get_nwarps_device();
1075
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1076
+
1077
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1078
+ int * x_qs = (int *) x_tile;
1079
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1080
+ #else
1081
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
1082
+ int * x_qs = (int *) x_tile;
1083
+ float * x_df = (float *) (x_qs + txs.qs);
1084
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1085
+
1086
+ constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
1087
+ constexpr int rows_per_warp = warp_size / threads_per_row;
1088
+ const int kbx = threadIdx.x % threads_per_row;
1089
+ const int row_in_warp = threadIdx.x / threads_per_row;
1090
+
1091
+ #pragma unroll
1092
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
1093
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
1094
+
1095
+ if constexpr (need_check) {
1096
+ i = min(i, i_max);
1097
+ }
1098
+
1099
+ const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
1100
+ const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
1101
+ const int kqs = 16 * kbx;
1102
+ const int ksc = 4 * kbx;
1103
+
1104
+ #pragma unroll
1105
+ for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
1106
+ const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
1107
+ const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
1108
+
1109
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1110
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
1111
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
1112
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
1113
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
1114
+ x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
1115
+ #else
1116
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
1117
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
1118
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
1119
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
1120
+ x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
1121
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1122
+ }
1123
+ }
1124
+ }
1125
+
829
1126
  template <int mmq_x, int mmq_y>
830
1127
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
831
1128
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
887
1184
  tile_A A[ntx];
888
1185
  #pragma unroll
889
1186
  for (int n = 0; n < ntx; ++n) {
890
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
1187
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
891
1188
  }
892
1189
 
893
1190
  #pragma unroll
894
1191
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
895
1192
  tile_B B;
896
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1193
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
897
1194
 
898
1195
  float dB;
899
1196
  const int j = j0 + tile_C::get_j(0);
@@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
996
1293
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
997
1294
  }
998
1295
 
999
- template <int mmq_x, int mmq_y>
1000
- static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
1001
- const int * __restrict__ y,
1002
- float * __restrict__ sum,
1003
- const int k00) {
1004
- typedef tile<16, 8, int> tile_A;
1005
- typedef tile<8, 8, int> tile_B;
1006
- typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
1007
-
1008
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1009
- constexpr int rows_per_warp = 2 * granularity;
1010
- constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
1011
-
1012
- y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
1013
-
1014
- // Match layout from load_tiles_mxfp4_fp4
1015
- const int * x_qs = (const int *) x;
1016
- const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
1017
- const int * y_qs = (const int *) y + 4;
1018
- const uint32_t * y_sc = (const uint32_t *) y;
1019
-
1020
- // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
1021
- tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1022
- uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
1023
-
1024
- // Block scale
1025
- // Each thread has to point to a 4 byte scale value
1026
- // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1027
-
1028
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1029
-
1030
- #pragma unroll
1031
- for (int n = 0; n < ntx; ++n) {
1032
- #pragma unroll
1033
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1034
- const int k0 = k00 + k01;
1035
-
1036
- load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
1037
- MMQ_MMA_TILE_X_K_FP4);
1038
-
1039
- // based on block-scaling document, 2 threads in each quad need to supply to the scale value
1040
- const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
1041
- scaleA[n][k01 / (2 * QI_MXFP4)] =
1042
- *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
1043
- }
1044
- }
1045
-
1046
- #pragma unroll
1047
- for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
1048
- #pragma unroll
1049
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
1050
- tile_B B;
1051
- uint32_t scaleB; // 2xN scales
1052
-
1053
- load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
1054
-
1055
- scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
1056
-
1057
- #pragma unroll
1058
- for (int n = 0; n < ntx; ++n) {
1059
- tile_C C;
1060
-
1061
- mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
1062
- #pragma unroll
1063
- for (int l = 0; l < tile_C::ne; ++l) {
1064
- sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
1065
- }
1066
- }
1067
- }
1068
- }
1069
- }
1070
1296
 
1071
1297
  template <int mmq_x, int mmq_y>
1072
1298
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
@@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1128
1354
  tile_A A[ntx];
1129
1355
  #pragma unroll
1130
1356
  for (int n = 0; n < ntx; ++n) {
1131
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1357
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1132
1358
  }
1133
1359
 
1134
1360
  #pragma unroll
1135
1361
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1136
1362
  tile_B B;
1137
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1363
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1138
1364
 
1139
1365
  const int j = j0 + tile_C::get_j(0);
1140
1366
  const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
1229
1455
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1230
1456
  }
1231
1457
 
1232
- // Used for Q3_K, IQ2_S, and IQ2_XS
1458
+ // Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
1233
1459
  template <int mmq_x, int mmq_y>
1234
1460
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
1235
1461
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
1268
1494
  template <int mmq_x, int mmq_y>
1269
1495
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1270
1496
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1271
- #if defined(AMD_MFMA_AVAILABLE)
1272
- constexpr data_layout input_layout = get_input_data_layout();
1273
- typedef tile<16, 8, int, input_layout> tile_A;
1274
- typedef tile<16, 8, int, input_layout> tile_B;
1275
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1276
- typedef tile<64, 2, int, input_layout> tile_load;
1277
-
1278
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1279
- constexpr int rows_per_warp = granularity;
1280
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1281
-
1282
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1283
-
1284
- const int * x_qs = (const int *) x;
1285
- const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1286
- const int * y_qs = (const int *) y + 4;
1287
- const float * y_df = (const float *) y;
1288
-
1289
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1290
-
1291
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1292
- const int k0 = k00 + k01;
1293
-
1294
- tile_A A[ntx];
1295
- #pragma unroll
1296
- for (int n = 0; n < ntx; ++n) {
1297
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1298
- }
1299
-
1300
- #pragma unroll
1301
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1302
- tile_B B[1];
1303
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1304
-
1305
- const int j = j0 + tile_C::get_j(0);
1306
- const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1307
-
1308
- #pragma unroll
1309
- for (int n = 0; n < ntx; ++n) {
1310
- tile_C C;
1311
- mma(C, A[n], B[0]);
1312
-
1313
- #pragma unroll
1314
- for (int l = 0; l < tile_C::ne; ++l) {
1315
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1316
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1317
- }
1318
- }
1319
- }
1320
- }
1321
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1497
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1322
1498
  constexpr data_layout input_layout = get_input_data_layout();
1323
1499
  typedef tile<16, 4, int, input_layout> tile_A;
1324
1500
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1343
1519
  tile_A A[ntx];
1344
1520
  #pragma unroll
1345
1521
  for (int n = 0; n < ntx; ++n) {
1346
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1522
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1347
1523
  }
1348
1524
 
1349
1525
  #pragma unroll
1350
1526
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1351
1527
  tile_B B;
1352
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1528
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1353
1529
 
1354
1530
  const int j = j0 + tile_C::get_j(0);
1355
1531
  const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1575
1751
  template <int mmq_x, int mmq_y>
1576
1752
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1577
1753
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1578
- #if defined(AMD_MFMA_AVAILABLE)
1579
- constexpr data_layout input_layout = get_input_data_layout();
1580
- typedef tile<16, 8, int, input_layout> tile_A;
1581
- typedef tile<16, 8, int, input_layout> tile_B;
1582
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
1583
- typedef tile<64, 2, int, input_layout> tile_load;
1584
-
1585
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1586
- constexpr int rows_per_warp = granularity;
1587
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1588
-
1589
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1590
-
1591
- const int * x_qs = (const int *) x;
1592
- const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1593
- const int * y_qs = (const int *) y + 4;
1594
- const half2 * y_ds = (const half2 *) y;
1595
-
1596
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1597
-
1598
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1599
- const int k0 = k00 + k01;
1600
-
1601
- tile_A A[ntx];
1602
- #pragma unroll
1603
- for (int n = 0; n < ntx; ++n) {
1604
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1605
- }
1606
-
1607
- #pragma unroll
1608
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1609
- tile_B B[1];
1610
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1611
-
1612
- const int j = j0 + tile_C::get_j(0);
1613
- const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1614
- const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1615
- : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1616
- : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1617
-
1618
- tile_C Cm;
1619
- if (k01 >= MMQ_TILE_NE_K * 3/4) {
1620
- tile_A A1;
1621
- A1.x[0] = 0x01010101;
1622
- A1.x[1] = 0x01010101;
1623
- mma(Cm, A1, B[0]);
1624
- }
1625
-
1626
- #pragma unroll
1627
- for (int n = 0; n < ntx; ++n) {
1628
- tile_C Cd;
1629
- mma(Cd, A[n], B[0]);
1630
-
1631
- #pragma unroll
1632
- for (int l = 0; l < tile_C::ne; ++l) {
1633
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1634
- const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1635
- float tmp = Cd.x[l]*dm.x;
1636
- if (k01 >= MMQ_TILE_NE_K * 3/4) {
1637
- tmp -= Cm.x[l]*dm.y;
1638
- }
1639
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1640
- sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1641
- }
1642
- }
1643
- }
1644
- }
1645
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
1754
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
1646
1755
  constexpr data_layout input_layout = get_input_data_layout();
1647
1756
  typedef tile<16, 4, int, input_layout> tile_A;
1648
1757
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1667
1776
  tile_A A[ntx];
1668
1777
  #pragma unroll
1669
1778
  for (int n = 0; n < ntx; ++n) {
1670
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1779
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1671
1780
  }
1672
1781
 
1673
1782
  #pragma unroll
1674
1783
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1675
1784
  tile_B B;
1676
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1785
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1677
1786
 
1678
1787
  const int j = j0 + tile_C::get_j(0);
1679
1788
  const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
@@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
2406
2515
  template <int mmq_x, int mmq_y>
2407
2516
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2408
2517
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2409
- #if defined(AMD_MFMA_AVAILABLE)
2410
- constexpr data_layout input_layout = get_input_data_layout();
2411
- typedef tile<16, 8, int, input_layout> tile_A;
2412
- typedef tile<16, 8, int, input_layout> tile_B;
2413
- typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
2414
- typedef tile<64, 2, int, input_layout> tile_load;
2415
-
2416
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
2417
- constexpr int rows_per_warp = granularity;
2418
- constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2419
-
2420
- y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2421
-
2422
- const int * x_qs = (const int *) x;
2423
- const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2424
- const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2425
- const int * y_qs = (const int *) y + 4;
2426
- const float * y_df = (const float *) y;
2427
-
2428
- const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2429
-
2430
- for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2431
- const int k0 = k00 + k01;
2432
-
2433
- tile_A A[ntx];
2434
- #pragma unroll
2435
- for (int n = 0; n < ntx; ++n) {
2436
- load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2437
- }
2438
-
2439
- #pragma unroll
2440
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2441
- tile_B B[1];
2442
- load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2443
-
2444
- const int j = j0 + tile_C::get_j(0);
2445
- const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2446
-
2447
- #pragma unroll
2448
- for (int n = 0; n < ntx; ++n) {
2449
- tile_C C;
2450
- mma(C, A[n], B[0]);
2451
-
2452
- #pragma unroll
2453
- for (int l = 0; l < tile_C::ne; ++l) {
2454
- const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2455
- const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2456
- sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2457
- }
2458
- }
2459
- }
2460
- }
2461
- #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
2518
+ #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2462
2519
  constexpr data_layout input_layout = get_input_data_layout();
2463
2520
  typedef tile<16, 4, int, input_layout> tile_A;
2464
2521
  typedef tile<16, 4, int, input_layout> tile_B;
@@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2484
2541
  tile_A A[ntx];
2485
2542
  #pragma unroll
2486
2543
  for (int n = 0; n < ntx; ++n) {
2487
- load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2544
+ load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2488
2545
  }
2489
2546
 
2490
2547
  #pragma unroll
2491
2548
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2492
2549
  tile_B B;
2493
- load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2550
+ load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2494
2551
 
2495
2552
  const int j = j0 + tile_C::get_j(0);
2496
2553
  const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -2715,14 +2772,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2715
2772
 
2716
2773
  #pragma unroll
2717
2774
  for (int l = 0; l < QR2_XXS; ++l) {
2718
- const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
2719
- const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
2775
+ const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
2776
+ const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
2720
2777
 
2721
- const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
2722
- const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
2778
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2779
+ const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
2723
2780
 
2724
- const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2725
- const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2781
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2782
+ const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
2726
2783
 
2727
2784
  #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2728
2785
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2790,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2733
2790
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2734
2791
  }
2735
2792
 
2736
- const int ls = aux32 >> 28;
2793
+ const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
2737
2794
  const float d = bxi->d;
2738
2795
  #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2739
- x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2796
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
2740
2797
  #else
2741
- x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2798
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
2742
2799
  #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2743
2800
  }
2744
2801
  }
@@ -2776,11 +2833,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2776
2833
 
2777
2834
  #pragma unroll
2778
2835
  for (int l = 0; l < QR2_XS; ++l) {
2779
- const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
2780
- const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
2836
+ const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
2837
+ const uint32_t signs = unpack_ksigns(q2[l] >> 9);
2781
2838
 
2782
- const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2783
- const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2839
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2840
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2841
+
2842
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2843
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2784
2844
 
2785
2845
  #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2786
2846
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2964,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2904
2964
  #pragma unroll
2905
2965
  for (int l = 0; l < QR3_XXS; ++l) {
2906
2966
  const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
2967
+ const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
2907
2968
 
2908
- const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
2969
+ const int signs0 = __vcmpne4(signs & 0x08040201, 0);
2970
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2909
2971
 
2910
- const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2911
- const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2972
+ const int signs1 = __vcmpne4(signs & 0x80402010, 0);
2973
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2912
2974
 
2913
2975
  #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
2914
2976
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
@@ -3203,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
3203
3265
  template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
3204
3266
  struct mmq_type_traits;
3205
3267
 
3268
+ template <int mmq_x, int mmq_y, bool need_check>
3269
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
3270
+ static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
3271
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
3272
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3273
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3274
+ };
3275
+
3206
3276
  template <int mmq_x, int mmq_y, bool need_check>
3207
3277
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
3208
3278
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
@@ -3248,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3248
3318
  static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
3249
3319
  #ifdef BLACKWELL_MMA_AVAILABLE
3250
3320
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
3251
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
3321
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>;
3252
3322
  #else
3253
3323
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
3254
3324
  static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
@@ -3256,6 +3326,19 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
3256
3326
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3257
3327
  };
3258
3328
 
3329
+ template <int mmq_x, int mmq_y, bool need_check>
3330
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
3331
+ static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
3332
+ #ifdef BLACKWELL_MMA_AVAILABLE
3333
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>;
3334
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>;
3335
+ #else
3336
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
3337
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3338
+ #endif // BLACKWELL_MMA_AVAILABLE
3339
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3340
+ };
3341
+
3259
3342
  template <int mmq_x, int mmq_y, bool need_check>
3260
3343
  struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
3261
3344
  static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -3387,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
3387
3470
 
3388
3471
  #if defined(BLACKWELL_MMA_AVAILABLE)
3389
3472
  // FP4 tile stores 8 blocks
3390
- constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
3473
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1;
3391
3474
  #else
3392
3475
  constexpr int ne_block = 4 * QK8_1;
3393
3476
  #endif // defined(BLACKWELL_MMA_AVAILABLE)
@@ -3459,10 +3542,10 @@ template <ggml_type type, int mmq_x, bool need_check>
3459
3542
  static __global__ void mul_mat_q(
3460
3543
  const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
3461
3544
  const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3462
- const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3463
- const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3464
- const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3465
- const int ncols_max) {
3545
+ const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3546
+ const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3547
+ const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3548
+ const uint3 ntx) {
3466
3549
 
3467
3550
  // Skip unused template specializations for faster compilation:
3468
3551
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3476,8 +3559,7 @@ static __global__ void mul_mat_q(
3476
3559
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
3477
3560
  constexpr int mmq_y = get_mmq_y_device();
3478
3561
 
3479
- const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
3480
- const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3562
+ const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3481
3563
 
3482
3564
  // Initialize the ids for writing back data with just the index.
3483
3565
  // For regular matrix multiplications this is never changed.
@@ -3498,8 +3580,9 @@ static __global__ void mul_mat_q(
3498
3580
  // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
3499
3581
  #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3500
3582
  {
3501
- const int wt = blockIdx.z / nchannels_y;
3502
- const int zt = blockIdx.z - wt*nchannels_y;
3583
+ const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y);
3584
+ const int wt = tmp2.x;
3585
+ const int zt = tmp2.y;
3503
3586
  const int jt = blockIdx.y;
3504
3587
  const int it = blockIdx.x;
3505
3588
 
@@ -3542,40 +3625,40 @@ static __global__ void mul_mat_q(
3542
3625
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3543
3626
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3544
3627
 
3545
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3628
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3546
3629
 
3547
3630
  constexpr bool fixup = false;
3548
3631
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3549
3632
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3550
- tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
3633
+ tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z);
3551
3634
  return;
3552
3635
  }
3553
- #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3636
+ #endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3554
3637
 
3555
- constexpr int ITER_K = get_iter_k(type);
3556
-
3557
- const int64_t blocks_per_ne00 = ncols_x / qk;
3558
- constexpr int blocks_per_iter = ITER_K / qk;
3638
+ constexpr int ITER_K = get_iter_k(type);
3639
+ constexpr int blocks_per_iter = ITER_K / qk;
3559
3640
 
3560
3641
  // kbc == k block continuous, current index in continuous ijk space.
3561
- int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3562
- int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3642
+ int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3643
+ int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3563
3644
 
3564
- kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3565
- kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
3645
+ kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
3646
+ kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter;
3566
3647
 
3567
3648
  // kb0 == k index when doing the matrix multiplication for an output tile.
3568
- int kb0_start = kbc % blocks_per_ne00;
3569
- int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
3570
- while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
3571
- int tmp = kbc;
3572
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3573
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3574
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3575
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3576
- const int zt = tmp / (ntx*blocks_per_ne00);
3577
- tmp -= zt * (ntx*blocks_per_ne00);
3578
- const int jt = tmp / blocks_per_ne00;
3649
+ int kb0_start = fastmodulo(kbc, blocks_per_ne00);
3650
+ int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc));
3651
+ while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) {
3652
+ int tmp = fastdiv(kbc, blocks_per_ne00);
3653
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3654
+ const int jt = tmp2.y;
3655
+ tmp = tmp2.x;
3656
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3657
+ const int zt = tmp2.y;
3658
+ tmp = tmp2.x;
3659
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3660
+ const int wt = tmp2.y;
3661
+ const int it = tmp2.x;
3579
3662
 
3580
3663
  // Defaults for regular matrix multiplication:
3581
3664
  int col_low = 0;
@@ -3593,11 +3676,11 @@ static __global__ void mul_mat_q(
3593
3676
  offset_dst = 0;
3594
3677
 
3595
3678
  if (jt*mmq_x >= col_diff) {
3596
- kbc += blocks_per_ne00;
3597
- kbc -= kbc % blocks_per_ne00;
3679
+ kbc += blocks_per_ne00.z;
3680
+ kbc -= fastmodulo(kbc, blocks_per_ne00);
3598
3681
 
3599
3682
  kb0_start = 0;
3600
- kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3683
+ kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
3601
3684
 
3602
3685
  continue;
3603
3686
  }
@@ -3622,32 +3705,34 @@ static __global__ void mul_mat_q(
3622
3705
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3623
3706
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3624
3707
 
3625
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3708
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3626
3709
 
3627
3710
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
3628
3711
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3629
3712
  (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3630
3713
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3631
3714
 
3632
- kbc += blocks_per_ne00;
3633
- kbc -= kbc % blocks_per_ne00;
3715
+ kbc += blocks_per_ne00.z;
3716
+ kbc -= fastmodulo(kbc, blocks_per_ne00);
3634
3717
 
3635
3718
  kb0_start = 0;
3636
- kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3719
+ kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc));
3637
3720
  }
3638
3721
 
3639
3722
  if (kbc >= kbc_stop) {
3640
3723
  return;
3641
3724
  }
3642
3725
 
3643
- int tmp = kbc;
3644
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3645
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3646
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3647
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3648
- const int zt = tmp / (ntx*blocks_per_ne00);
3649
- tmp -= zt * (ntx*blocks_per_ne00);
3650
- const int jt = tmp / blocks_per_ne00;
3726
+ int tmp = fastdiv(kbc, blocks_per_ne00);
3727
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3728
+ const int jt = tmp2.y;
3729
+ tmp = tmp2.x;
3730
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3731
+ const int zt = tmp2.y;
3732
+ tmp = tmp2.x;
3733
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3734
+ const int wt = tmp2.y;
3735
+ const int it = tmp2.x;
3651
3736
 
3652
3737
  // Defaults for regular matrix multiplication:
3653
3738
  int col_low = 0;
@@ -3689,7 +3774,7 @@ static __global__ void mul_mat_q(
3689
3774
  const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3690
3775
  const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3691
3776
 
3692
- const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3777
+ const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3693
3778
 
3694
3779
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
3695
3780
  mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
@@ -3697,40 +3782,38 @@ static __global__ void mul_mat_q(
3697
3782
  tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3698
3783
  }
3699
3784
 
3700
-
3701
3785
  template <ggml_type type, int mmq_x, bool need_check>
3786
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1)
3702
3787
  static __global__ void mul_mat_q_stream_k_fixup(
3703
- const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
3704
- const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
3705
- const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
3706
- const int ncols_max) {
3707
- constexpr int mmq_y = get_mmq_y_device();
3708
- constexpr int qk = ggml_cuda_type_traits<type>::qk;
3709
- constexpr int ITER_K = get_iter_k(type);
3710
-
3711
- constexpr int blocks_per_iter = ITER_K / qk;
3712
- const int64_t blocks_per_ne00 = ncols_x / qk;
3788
+ const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
3789
+ float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst,
3790
+ const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y,
3791
+ const int stride_sample_dst, const uint3 ntx) {
3792
+ constexpr int mmq_y = get_mmq_y_device();
3793
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
3794
+ constexpr int ITER_K = get_iter_k(type);
3795
+ constexpr int blocks_per_iter = ITER_K / qk;
3713
3796
 
3714
- constexpr int nwarps = mmq_get_nwarps_device();
3797
+ constexpr int nwarps = mmq_get_nwarps_device()/2;
3715
3798
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3716
3799
 
3717
- float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3800
+ float sum[mmq_x / nwarps] = {0.0f};
3801
+ const int i = blockIdx.y*warp_size + threadIdx.x;
3718
3802
 
3719
- const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
3720
- const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3803
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3721
3804
 
3722
3805
  const int bidx0 = blockIdx.x;
3723
3806
 
3724
3807
  // kbc == k block continuous, current index in continuous ijk space.
3725
- int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3726
- int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3808
+ int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3809
+ int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3727
3810
 
3728
- kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
3729
- kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
3811
+ kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter;
3812
+ kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter;
3730
3813
 
3731
3814
  const bool did_not_have_any_data = kbc0 == kbc0_stop;
3732
- const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
3733
- const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
3815
+ const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0;
3816
+ const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0;
3734
3817
  if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
3735
3818
  return;
3736
3819
  }
@@ -3739,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup(
3739
3822
 
3740
3823
  // Iterate over previous blocks and sum up partial sums written to fixup buffer.
3741
3824
  // All CUDA blocks that get here must have a previous block that needs a fixup.
3742
- int64_t bidx = bidx0 - 1;
3743
- int64_t kbc_stop = kbc0;
3825
+ int bidx = bidx0 - 1;
3826
+ int kbc_stop = kbc0;
3744
3827
  while(true) {
3745
- int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3746
- kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3828
+ int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x;
3829
+ kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter;
3747
3830
 
3748
3831
  if (kbc == kbc_stop) { // Did not have any data.
3749
3832
  bidx--;
@@ -3753,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup(
3753
3836
 
3754
3837
  any_fixup = true;
3755
3838
 
3839
+
3756
3840
  #pragma unroll
3757
3841
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3758
3842
  const int j = j0 + threadIdx.y;
3759
3843
 
3760
- #pragma unroll
3761
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3762
- const int i = i0 + threadIdx.x;
3763
-
3764
- sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3765
- }
3844
+ sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3766
3845
  }
3767
3846
 
3768
3847
  // If this block started in a previous tile we are done and don't need to combine additional partial results.
3769
- if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
3848
+ if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) {
3770
3849
  break;
3771
3850
  }
3772
3851
  bidx--;
@@ -3777,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup(
3777
3856
  return;
3778
3857
  }
3779
3858
 
3780
- int tmp = kbc0;
3781
- const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3782
- tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3783
- const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3784
- tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3785
- const int zt = tmp / (ntx*blocks_per_ne00);
3786
- tmp -= zt * (ntx*blocks_per_ne00);
3787
- const int jt = tmp / blocks_per_ne00;
3859
+ int tmp = fastdiv(kbc0, blocks_per_ne00);
3860
+ uint2 tmp2 = fast_div_modulo(tmp, ntx);
3861
+ const int jt = tmp2.y;
3862
+ tmp = tmp2.x;
3863
+ tmp2 = fast_div_modulo(tmp, nchannels_y);
3864
+ const int zt = tmp2.y;
3865
+ tmp = tmp2.x;
3866
+ tmp2 = fast_div_modulo(tmp, nsamples_y);
3867
+ const int wt = tmp2.y;
3868
+ const int it = tmp2.x;
3788
3869
 
3789
3870
  if (!ids_dst) {
3790
3871
  const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
@@ -3792,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
3792
3873
 
3793
3874
  const int i_max = nrows_x - it*mmq_y - 1;
3794
3875
  const int j_max = ncols_dst - jt*mmq_x - 1;
3876
+ if (need_check && i > i_max) {
3877
+ return;
3878
+ }
3795
3879
 
3796
3880
  #pragma unroll
3797
3881
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3801,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
3801
3885
  return;
3802
3886
  }
3803
3887
 
3804
- #pragma unroll
3805
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3806
- const int i = i0 + threadIdx.x;
3807
-
3808
- if (need_check && i > i_max) {
3809
- continue;
3810
- }
3811
-
3812
- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3813
- }
3888
+ dst[j*stride_col_dst + i] += sum[j0/nwarps];
3814
3889
  }
3815
3890
  return;
3816
3891
  }
@@ -3830,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
3830
3905
 
3831
3906
  const int i_max = nrows_x - it*mmq_y - 1;
3832
3907
  const int j_max = col_diff - jt*mmq_x - 1;
3908
+ if (need_check && i > i_max) {
3909
+ return;
3910
+ }
3833
3911
 
3834
3912
  #pragma unroll
3835
3913
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -3839,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
3839
3917
  return;
3840
3918
  }
3841
3919
 
3842
- #pragma unroll
3843
- for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3844
- const int i = i0 + threadIdx.x;
3845
-
3846
- if (need_check && i > i_max) {
3847
- continue;
3848
- }
3849
-
3850
- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3851
- }
3920
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps];
3852
3921
  }
3853
3922
  }
3854
3923
 
@@ -3896,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3896
3965
  const int channel_ratio = args.nchannels_y / args.nchannels_x;
3897
3966
  const int sample_ratio = args.nsamples_y / args.nsamples_x;
3898
3967
 
3968
+ const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk);
3969
+ const uint3 ntx_fd = init_fastdiv_values(ntx);
3970
+ const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y);
3971
+ const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y);
3972
+ const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio);
3973
+ const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio);
3974
+
3899
3975
  if (!args.use_stream_k) {
3900
3976
  if (args.nrows_x % mmq_y == 0) {
3901
3977
  constexpr bool need_check = false;
3902
3978
  mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3903
3979
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3904
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3905
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3906
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3907
- args.ncols_max);
3980
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3981
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3982
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3983
+ ntx_fd);
3908
3984
  } else {
3909
3985
  constexpr bool need_check = true;
3910
3986
  mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3911
3987
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3912
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3913
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3914
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3915
- args.ncols_max);
3988
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3989
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3990
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3991
+ ntx_fd);
3916
3992
  }
3917
3993
  return;
3918
3994
  }
3919
3995
 
3920
- const dim3 block_nums_stream_k(nsm, 1, 1);
3921
- const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
3996
+ // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles.
3997
+ // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important.
3998
+ const int ntiles_dst = ntx * nty * ntzw;
3999
+ const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm;
4000
+ const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves);
4001
+ const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1);
4002
+
4003
+ GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow.
4004
+
4005
+ const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0;
3922
4006
 
3923
4007
  ggml_cuda_pool & pool = ctx.pool(id);
3924
4008
  ggml_cuda_pool_alloc<float> tmp_fixup(pool);
@@ -3926,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3926
4010
  tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
3927
4011
  }
3928
4012
 
4013
+ const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1);
4014
+ const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z);
4015
+
3929
4016
  if (args.nrows_x % mmq_y == 0) {
3930
4017
  constexpr bool need_check = false;
3931
4018
  mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3932
4019
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3933
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3934
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3935
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3936
- args.ncols_max);
4020
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
4021
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
4022
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
4023
+ ntx_fd);
3937
4024
 
3938
4025
  if (!fixup_needed) {
3939
4026
  return;
3940
4027
  }
3941
4028
 
3942
- mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3943
- (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3944
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3945
- args.ncols_max);
4029
+ CUDA_CHECK(cudaGetLastError());
4030
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
4031
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
4032
+ args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
4033
+ ntx_fd);
3946
4034
  } else {
3947
4035
  constexpr bool need_check = true;
3948
4036
  mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3949
4037
  (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3950
- args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3951
- channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3952
- sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3953
- args.ncols_max);
4038
+ blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
4039
+ channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
4040
+ sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
4041
+ ntx_fd);
3954
4042
 
3955
4043
  if (!fixup_needed) {
3956
4044
  return;
3957
4045
  }
3958
4046
 
3959
- mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3960
- (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3961
- args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3962
- args.ncols_max);
4047
+ CUDA_CHECK(cudaGetLastError());
4048
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>>
4049
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst,
4050
+ args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst,
4051
+ ntx_fd);
3963
4052
  }
3964
4053
  }
3965
4054
 
@@ -4057,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
4057
4146
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
4058
4147
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
4059
4148
  extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
4149
+ extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
4060
4150
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
4061
4151
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
4062
4152
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
@@ -4083,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q(
4083
4173
  const int64_t src1_padded_row_size, cudaStream_t stream);
4084
4174
 
4085
4175
  bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
4176
+